diff --git a/dex.cabal b/dex.cabal index 1b6b7b771..50703eeb6 100644 --- a/dex.cabal +++ b/dex.cabal @@ -48,7 +48,6 @@ library , Algebra , Builder , CUDA - , Cat , CheapReduction , CheckType , ConcreteSyntax @@ -58,6 +57,7 @@ library , Generalize , Imp , ImpToLLVM + , IncState , Inference , Inline , IRVariants @@ -70,13 +70,14 @@ library , LLVM.Shims , Lexing , Linearize - , Logging , Lower + , MonadUtil , MTL1 , Name , Occurrence , OccAnalysis , Optimize + , PeepholeOptimize , PPrint , RawName , Runtime @@ -84,17 +85,16 @@ library , Serialize , Simplify , Subst - , SourceInfo , SourceRename + , SourceIdTraversal , TopLevel , Transpose - , TraverseSourceInfo , Types.Core , Types.Imp - , Types.Misc , Types.Primitives , Types.OpNames , Types.Source + , Types.Top , QueryType , QueryTypePure , Util @@ -102,7 +102,6 @@ library if flag(live) exposed-modules: Actor , Live.Eval - , Live.Terminal , Live.Web , RenderHtml other-modules: Paths_dex @@ -126,7 +125,6 @@ library , prettyprinter , text -- Portable system utilities - , ansi-terminal , directory , filepath , haskeline @@ -135,11 +133,13 @@ library -- Serialization , aeson , store + , time -- Floating-point pedanticness (correcting for GHC < 9.2.2) , floating-bits if flag(live) build-depends: binary , blaze-html + , blaze-markup , cmark , http-types , wai @@ -234,6 +234,7 @@ executable dex main-is: dex.hs build-depends: dex , ansi-wl-pprint + , ansi-terminal , base , bytestring , containers diff --git a/examples/raytrace.dx b/examples/raytrace.dx index eccfa5695..993d6203c 100644 --- a/examples/raytrace.dx +++ b/examples/raytrace.dx @@ -15,7 +15,7 @@ def Vec(n:Nat) -> Type = Fin n => Float def Mat(n:Nat, m:Nat) -> Type = Fin n => Fin m => Float def relu(x:Float) -> Float = max x 0.0 -def length(x: d=>Float) -> Float given (d|Ix) = sqrt $ sum for i. sq x[i] +def length(x: d=>Float) -> Float given (d|Ix) = sqrt $ sum for i:d. sq x[i] -- TODO: make a newtype for normal vectors def normalize(x: d=>Float) -> d=>Float given (d|Ix) = x / (length x) def directionAndLength(x: d=>Float) -> (d=>Float, Float) given (d|Ix) = @@ -68,7 +68,7 @@ def rotateZ(p:Vec 3, angle:Angle) -> Vec 3 = [c*px - s*py, s*px+c*py, pz] def sampleCosineWeightedHemisphere(normal: Vec 3, k:Key) -> Vec 3 = - [k1, k2] = split_key k + [k1, k2] = split_key(n=2, k) u1 = rand k1 u2 = rand k2 uu = normalize $ cross normal [0.0, 1.1, 1.1] @@ -152,21 +152,21 @@ def sdObject(pos:Position, obj:Object) -> Distance = Wall(nor, d) -> d + dot nor pos Block(blockPos, halfWidths, angle) -> pos' = rotateY (pos - blockPos) angle - length $ for i. max ((abs pos'[i]) - halfWidths[i]) 0.0 + length $ for i:(Fin 3). max ((abs pos'[i]) - halfWidths[i]) 0.0 Sphere(spherePos, r) -> pos' = pos - spherePos max (length pos' - r) 0.0 Light(squarePos, hw, _) -> pos' = pos - squarePos halfWidths = [hw, 0.01, hw] - length $ for i. max ((abs pos'[i]) - halfWidths[i]) 0.0 + length $ for i:(Fin 3). max ((abs pos'[i]) - halfWidths[i]) 0.0 def sdScene(scene:Scene n, pos:Position) -> (Object, Distance) given (n|Ix) = - (i, d) = minimum_by snd $ for i. (i, sdObject pos scene.objects[i]) + (i, d) = minimum_by(for i:n. (i, sdObject pos scene.objects[i]), snd) (scene.objects[i], d) def calcNormal(obj:Object, pos:Position) -> Direction = - normalize (grad (\pos. sdObject pos obj) pos) + grad(\p:Position. sdObject(p, obj)) pos | normalize data RayMarchResult = -- incident ray, surface normal, surface properties @@ -176,7 +176,7 @@ data RayMarchResult = HitNothing def raymarch(scene:Scene n, ray:Ray) -> RayMarchResult given (n|Ix) = - maxIters = 100 + maxIters : Nat = 100 tol = 0.01 startLength = 10.0 * tol -- trying to escape the current surface with_state (10.0 * tol) \rayLength. @@ -209,7 +209,7 @@ def rayDirectRadiance(scene:Scene n, ray:Ray) -> Radiance given (n|Ix) = HitObj(_, _) -> zero def sampleSquare(hw:Float, k:Key) -> Position = - [kx, kz] = split_key k + [kx, kz] : Fin 2 => Key = split_key k x = randuniform (- hw) hw kx z = randuniform (- hw) hw kz [x, 0.0, z] @@ -220,7 +220,7 @@ def sampleLightRadiance( inRay:Ray, k:Key) -> Radiance given (n|Ix) = yield_accum (AddMonoid Float) \radiance. - for i. case scene.objects[i] of + each scene.objects \obj. case obj of PassiveObject(_, _) -> () Light(lightPos, hw, _) -> (dirToLight, distToLight) = directionAndLength $ @@ -244,7 +244,7 @@ def trace(params:Params, scene:Scene n, initRay:Ray, k:Key) -> Color given (n|Ix if i == 0 then radiance += intensity -- TODO: scale etc Done () HitObj(incidentRay, osurf) -> - [k1, k2] = split_key $ hash k i + [k1, k2] = split_key(n=2, hash k i) lightRadiance = sampleLightRadiance scene osurf incidentRay k1 ray := sampleReflection osurf incidentRay k2 filter := surfaceFilter (get filter) osurf.surface @@ -265,27 +265,24 @@ def cameraRays(n:Nat, camera:Camera) -> Fin n => Fin n => ((Key) -> Ray) = pixHalfWidth = halfWidth / n_to_f n ys = reverse $ linspace (Fin n) (neg halfWidth) halfWidth xs = linspace (Fin n) (neg halfWidth) halfWidth - for i j. \key. - [kx, ky] = split_key key + for i:(Fin n) j:(Fin n). \key. + [kx, ky] = split_key(n=2, key) x = xs[j] + randuniform (-pixHalfWidth) pixHalfWidth kx y = ys[i] + randuniform (-pixHalfWidth) pixHalfWidth ky Ray(camera.pos, normalize [x, y, neg camera.sensorDist]) def takePicture(params:Params, scene:Scene m, camera:Camera) -> Image given (m|Ix) = - n = camera.numPix - rays = cameraRays n camera + rays = cameraRays camera.numPix camera rootKey = new_key 0 - image = for i j. + image = for i:(Fin camera.numPix) j:(Fin camera.numPix). pixKey = if params.shareSeed then rootKey else ixkey (ixkey rootKey i) j def sampleRayColor(k:Key) -> Color = - [k1, k2] = split_key k + [k1, k2] = split_key(n=2, k) trace params scene (rays[i,j] k1) k2 sampleAveraged sampleRayColor params.numSamples pixKey - MkImage _ _ $ image / mean (for ixs. - (i,j,k) = ixs - image[i,j,k]) + MkImage _ _ $ image / mean(flatten3D(image)) '## Define the scene and render it diff --git a/lib/diagram.dx b/lib/diagram.dx index 66beff1ff..46fffd400 100644 --- a/lib/diagram.dx +++ b/lib/diagram.dx @@ -35,15 +35,16 @@ struct GeomStyle = default_geom_style = GeomStyle Nothing (Just black) 1 -- TODO: consider sharing attributes among a set of objects for efficiency +Object : Type = (GeomStyle, Point, Geom) struct Diagram = - val : (List (GeomStyle, Point, Geom)) + val : (List Object) instance Monoid(Diagram) mempty = Diagram mempty def (<>)(d1, d2) = Diagram $ d1.val <> d2.val def concat_diagrams(diagrams:n=>Diagram) -> Diagram given (n|Ix) = - Diagram $ concat for i. diagrams[i].val + Diagram $ concat $ each diagrams \d. d.val -- TODO: arbitrary affine transformations. Our current representation of -- rectangles and circles means we can only do scale/flip/rotate90. @@ -54,8 +55,8 @@ def apply_transformation( d:Diagram ) -> Diagram = AsList(_, objs) = d.val - Diagram $ to_list for i. - (attr, p, geom) = objs[i] + Diagram $ to_list $ each objs \obj. + (attr, p, geom) = obj (attr, transformPoint p, transformGeom geom) def flip_y(d:Diagram) -> Diagram = @@ -92,8 +93,8 @@ def text(x:String) -> Diagram = singleton_default $ Text x def update_geom(update: (GeomStyle) -> GeomStyle, d:Diagram) -> Diagram = AsList(_, objs) = d.val - Diagram $ to_list for i. - ( attr, point, geoms) = objs[i] + Diagram $ to_list $ each objs \obj. + ( attr, point, geoms) = obj (update attr, point, geoms) -- TODO: these would be better if we had field-access-based ref projections, so we could @@ -149,7 +150,7 @@ def (<=>)(attr:String, val:b) -> String given (b|Show) = attr <.> "=" <.> quote (show val) def html_color(cs:HtmlColor) -> String = - "#" <> (concat $ for i. showHex cs[i]) + "#" <> (concat $ each cs showHex) def optional_html_color(c: Maybe HtmlColor) -> String = case c of @@ -166,7 +167,7 @@ def attr_string(attr:GeomStyle) -> String = def render_geom(attr:GeomStyle, p:Point, geom:Geom) -> String = -- For things that are solid. SVG says they have fill=stroke. solidAttr = GeomStyle attr.strokeColor attr.strokeColor attr.strokeWidth - groupEle = \attr s. tag_brackets_attr "g" (attr_string attr) s + groupEle = \attr:GeomStyle s:String. tag_brackets_attr "g" (attr_string attr) s case geom of PointGeom -> groupEle solidAttr $ self_closing_brackets $ @@ -188,7 +189,7 @@ def render_geom(attr:GeomStyle, p:Point, geom:Geom) -> String = "x" <=> (p.x - (w/2.0)) <.> "y" <=> (p.y - (h/2.0))) Text content -> - textEle = \s. tag_brackets_attr("text", + textEle = \s:String. tag_brackets_attr("text", ("x" <=> p.x <.> "y" <=> p.y <.> "text-anchor" <=> "middle" <.> -- horizontal center @@ -200,8 +201,8 @@ BoundingBox : Type = (Point, Point) @noinline def compute_bounds(d:Diagram) -> BoundingBox = - computeSubBound = \sel op. - \triple. + computeSubBound = \sel:((Point) -> Float) op:((Float) -> Float). + \triple:Object. (_, p, geom) = triple sel p + case geom of PointGeom -> 0.0 @@ -213,12 +214,12 @@ def compute_bounds(d:Diagram) -> BoundingBox = AsList(_, objs) = d.val ( Point( - minimum $ map (computeSubBound (\p. p.x) neg) objs, - minimum $ map (computeSubBound (\p. p.y) neg) objs + minimum $ each objs (computeSubBound (\p. p.x) neg), + minimum $ each objs (computeSubBound (\p. p.y) neg) ), Point( - maximum $ map (computeSubBound (\p. p.x) id) objs, - maximum $ map (computeSubBound (\p. p.y) id) objs + maximum $ each objs (computeSubBound (\p. p.x) id), + maximum $ each objs (computeSubBound (\p. p.y) id) ) ) @@ -235,11 +236,11 @@ def render_svg(d:Diagram, bounds:BoundingBox) -> String = <+> "height" <=> imgHeight <+> "viewBox" <=> (imgXMin <+> imgYMin <+> imgWidth <+> imgHeight)) tag_brackets_attr "svg" svgAttrStr $ - concat for i. - (attr, pos, geom) = objs[i] + concat $ each objs \obj. + (attr, pos, geom) = obj render_geom attr pos geom -render_scaled_svg = \d. render_svg d (compute_bounds d) +render_scaled_svg = \d:Diagram. render_svg d (compute_bounds d) '## Derived convenience methods and combinators diff --git a/lib/plot.dx b/lib/plot.dx index f5f92a028..189ca543b 100644 --- a/lib/plot.dx +++ b/lib/plot.dx @@ -15,6 +15,7 @@ struct ScaledData(n|Ix, a:Type) = scale : Scale a dat : n => a +-- TODO: bundle up the type params into a triple of types struct Plot(n|Ix, a:Type, b:Type, c:Type) = xs : ScaledData n a ys : ScaledData n b @@ -22,7 +23,7 @@ struct Plot(n|Ix, a:Type, b:Type, c:Type) = Color : Type = Fin 3 => Float -def apply_scale(s:Scale a, x:a) -> Maybe Float given (a) = s.mapping x +def apply_scale(s:Scale a, x:a) -> Maybe Float given (a:Type) = s.mapping x unit_type_scale : Scale(()) = Scale (\_. Just 0.0) (AsList _ [Singleton 0.0]) @@ -33,12 +34,12 @@ def project_unit_interval(x:Float) -> Maybe Float = unit_interval_scale : Scale Float = Scale (project_unit_interval) (AsList _ [Interval 0.0 1.0]) -def map_scale(s:Scale a, f: (b) -> a) -> Scale b given (a, b) = Scale (\x. s.mapping (f x)) s.range +def map_scale(s:Scale a, f: (b) -> a) -> Scale b given (a:Type, b:Type) = Scale (\x. s.mapping (f x)) s.range def float_scale(xmin:Float, xmax:Float) -> Scale Float = map_scale unit_interval_scale (\x. (x - xmin) / (xmax - xmin)) -def get_scaled(sd:ScaledData n a, i:n) -> Maybe Float given (n|Ix, a) = +def get_scaled(sd:ScaledData n a, i:n) -> Maybe Float given (n|Ix, a:Type) = apply_scale sd.scale sd.dat[i] low_color = [1.0, 0.5, 0.0] @@ -54,8 +55,8 @@ def make_rgb_color(c: Color) -> HtmlColor = def color_scale(x:Float) -> HtmlColor = make_rgb_color $ interpolate low_color high_color x -def plot_to_diagram(plot:Plot n a b c) -> Diagram given (a, b, c, n|Ix) = - points = concat_diagrams for i. +def plot_to_diagram(plot:Plot n a b c) -> Diagram given (a:Type, b:Type, c:Type, n|Ix) = + points = concat_diagrams for i:n. x = get_scaled plot.xs i y = get_scaled plot.ys i c = get_scaled plot.cs i @@ -70,16 +71,17 @@ def plot_to_diagram(plot:Plot n a b c) -> Diagram given (a, b, c, n|Ix) = boundingBox = move_xy(rect 1.0 1.0, 0.5, 0.5) boundingBox <> points -def show_plot(plot:Plot n a b c) -> String given (a, b, c, n|Ix) = +def show_plot(plot:Plot n a b c) -> String given (a:Type, b:Type, c:Type, n|Ix) = render_svg (plot_to_diagram plot) (Point 0.0 0.0, Point 1.0 1.0) -def blank_data() ->> ScaledData n () given (n|Ix) = - ScaledData unit_type_scale (for i. ()) +def blank_data(n|Ix) -> ScaledData n () = + ScaledData unit_type_scale (for i:n. ()) -def blank_plot() ->> Plot n () () () given (n|Ix) = - Plot blank_data blank_data blank_data +def blank_plot(n|Ix) -> Plot n () () () = + -- TODO: figure out why we need the annotations here. Top-down inference should work. + Plot(blank_data(n), blank_data(n), blank_data(n)) --- -- TODO: generalize beyond Float with a type class for auto scaling +-- TODO: generalize beyond Float with a type class for auto scaling def auto_scale(xs:n=>Float) -> ScaledData n Float given (n|Ix) = max = maximum xs min = minimum xs @@ -88,29 +90,29 @@ def auto_scale(xs:n=>Float) -> ScaledData n Float given (n|Ix) = padding = maximum [space, max * 0.001, 0.000001] ScaledData (float_scale (min - padding) (max + padding)) xs -def set_x_data(plot:Plot n a b c, xs:ScaledData n new) -> Plot n new b c given (n|Ix, a, b, c, new) = +def set_x_data(plot:Plot n a b c, xs:ScaledData n new) -> Plot n new b c given (n|Ix, a:Type, b:Type, c:Type, new:Type) = -- We can't use `setAt` here because we're changing the type Plot xs plot.ys plot.cs -def set_y_data(plot:Plot n a b c, ys:ScaledData n new) -> Plot n a new c given (n|Ix, a, b, c, new) = +def set_y_data(plot:Plot n a b c, ys:ScaledData n new) -> Plot n a new c given (n|Ix, a:Type, b:Type, c:Type, new:Type) = Plot plot.xs ys plot.cs -def set_c_data(plot:Plot n a b c, cs:ScaledData n new) -> Plot n a b new given (n|Ix, a, b, c, new) = +def set_c_data(plot:Plot n a b c, cs:ScaledData n new) -> Plot n a b new given (n|Ix, a:Type, b:Type, c:Type, new:Type) = Plot plot.xs plot.ys cs def xy_plot(xs:n=>Float, ys:n=>Float) -> Plot n Float Float () given (n|Ix) = - blank_plot | + blank_plot(n) | set_x_data (auto_scale xs) | set_y_data (auto_scale ys) def xyc_plot(xs:n=>Float, ys:n=>Float, cs:n=>Float) -> Plot n Float Float Float given (n|Ix) = - blank_plot | + blank_plot(n) | set_x_data (auto_scale xs) | set_y_data (auto_scale ys) | set_c_data (auto_scale cs) def y_plot(ys:n=>Float) -> Plot n Float Float () given (n|Ix) = - xs = for i. n_to_f $ ordinal i + xs = for i:n. n_to_f $ ordinal i xy_plot xs ys -- xs = linspace (Fin 100) 0. 1.0 @@ -120,14 +122,10 @@ def y_plot(ys:n=>Float) -> Plot n Float Float () given (n|Ix) = -- TODO: scales def matshow(img:n=>m=>Float) -> Html given (n|Ix, m|Ix) = - low = minimum $ for p. - (i, j) = p - img[i,j] - high = maximum $ for p. - (i, j) = p - img[i,j] + low = minimum $ flatten2D(img) + high = maximum $ flatten2D(img) range = high - low - img_to_html $ make_png for i j. + img_to_html $ make_png for i:n j:m. x = if range == 0.0 then float_to_8bit $ 0.5 else float_to_8bit $ (img[i,j] - low) / range diff --git a/lib/png.dx b/lib/png.dx index faa6b7db7..3937b0350 100644 --- a/lib/png.dx +++ b/lib/png.dx @@ -28,7 +28,7 @@ Base64 = Byte -- first two bits should be zero -- This could go in the prelude, or in a library of array-dicing functions. -- An explicit "view" builder would be good here, to avoid copies def get_chunks(chunkSize:Nat, padVal:a, xs:n=>a) - -> List (Fin chunkSize => a) given (n|Ix, a) = + -> List (Fin chunkSize => a) given (n|Ix, a:Type) = numChunks = idiv_ceil (size n) chunkSize paddedSize = numChunks * chunkSize xsPadded = pad_to (Fin paddedSize) padVal xs @@ -44,29 +44,28 @@ def base64s_to_bytes(chunk : Fin 4 => Base64) -> Fin 3 => Byte = def bytes_to_base64s(chunk : Fin 3 => Byte) -> Fin 4 => Base64 = [a, b, c] = chunk -- '?' is 00111111 - map (\x. x .&. '?') $ - [ a .>>. 2 - , (a .<<. 4) .|. (b .>>. 4) - , (b .<<. 2) .|. (c .>>. 6) - , c ] + tmp = [ a .>>. 2 + , (a .<<. 4) .|. (b .>>. 4) + , (b .<<. 2) .|. (c .>>. 6) + , c ] + each tmp \x. x .&. '?' def base64_to_ascii(x:Base64) -> Char = encoding_table[from_ordinal (w8_to_n x)] def encode_chunk(chunk : Fin 3 => Char) -> Fin 4 => Char = - map base64_to_ascii $ bytes_to_base64s chunk + each (bytes_to_base64s chunk) base64_to_ascii -- TODO: the `AsList` unpacking is very tedious. Daniel's change will help def base64_encode(s:String) -> String = AsList(n, cs) = s AsList(numChunks, chunks) = get_chunks 3 '\NUL' cs - encodedChunks = map encode_chunk chunks - flattened = for pair. - (i, j) = pair - encodedChunks[i, j] + encodedChunks = each chunks encode_chunk + FlatIxType : Type = (Fin numChunks, Fin 4) + flattened = flatten2D(encodedChunks) padChars = rem (unsafe_nat_diff 3 (rem n 3)) 3 validOutputChars = unsafe_nat_diff (numChunks * 4) padChars - to_list for i. case ordinal i < validOutputChars of + to_list for i:FlatIxType. case ordinal i < validOutputChars of True -> flattened[i] False -> '=' @@ -74,7 +73,7 @@ def ascii_to_base64(c:Char) -> Maybe Base64 = decoding_table[from_ordinal (w8_to_n c)] def decode_chunk(chunk : Fin 4 => Char) -> Maybe (Fin 3 => Char) = - case seq_maybes $ map ascii_to_base64 chunk of + case chunk | each(ascii_to_base64) | seq_maybes of Nothing -> Nothing Just base64s -> Just $ base64s_to_bytes base64s @@ -87,16 +86,14 @@ def replace(pair:(a,a), x:a) -> a given (a|Eq) = def base64_decode(s:String) -> Maybe String = AsList(n, cs) = s - numValidInputChars = sum for i. b_to_n $ cs[i] /= '=' + numValidInputChars = sum for i:(Fin n). b_to_n $ cs[i] /= '=' numValidOutputChars = idiv (numValidInputChars * 3) 4 - csZeroed = map (\x. replace(('=', 'A'), x)) cs -- swap padding char with 'zero' char + csZeroed = each cs \c. replace(('=', 'A'), c) -- swap padding char with 'zero' char AsList(_, chunks) = get_chunks 4 '\NUL' csZeroed - case seq_maybes $ map decode_chunk chunks of + case chunks | each(decode_chunk) | seq_maybes of Nothing -> Nothing Just decodedChunks -> - resultPadded = for pair. - (i, j) = pair - decodedChunks[i, j] + resultPadded = flatten2D(decodedChunks) Just $ to_list $ slice resultPadded 0 (Fin numValidOutputChars) '## PNG FFI @@ -108,7 +105,7 @@ Gif : Type = String foreign "encodePNG" encodePNG : (RawPtr, Word32, Word32) -> {IO} (Word32, RawPtr) def make_png(img:n=>m=>(Fin 3)=>Word8) -> Png given (n|Ix, m|Ix) = unsafe_io \. - AsList(_, imgFlat) = to_list for triple. + AsList(_, imgFlat) = to_list for triple:(n,(m,Fin 3)). (i, (j, k)) = triple img[i, j, k] with_table_ptr imgFlat \ptr. @@ -116,13 +113,13 @@ def make_png(img:n=>m=>(Fin 3)=>Word8) -> Png given (n|Ix, m|Ix) = unsafe_io \. (sz, ptr') = encodePNG rawPtr (nat_to_rep $ size m) (nat_to_rep $ size n) AsList((rep_to_nat sz), table_from_ptr(Ptr(ptr'))) -def pngs_to_gif(delay:Int, pngs:t=>Png) -> Gif given (t|Ix) = unsafe_io \. - with_temp_files \pngFiles. - for i. write_file pngFiles[i] pngs[i] +def pngs_to_gif(pngs:t=>Png, delay:Int) -> Gif given (t|Ix) = unsafe_io \. + with_temp_files(t) \pngFiles. + for i:t. write_file pngFiles[i] pngs[i] with_temp_file \gifFile. shell_out $ "convert" <> " -delay " <> show delay <> " " <> - concat (for i. "png:" <> pngFiles[i] <> " ") <> + concat (for i:t. "png:" <> pngFiles[i] <> " ") <> "gif:" <> gifFile read_file gifFile @@ -135,7 +132,7 @@ def float_to_8bit(x:Float) -> Word8 = n_to_w8 $ f_to_n $ 255.0 * clip (0.0, 1.0) x def img_to_png(img:n=>m=>(Fin 3)=>Float) -> Png given (n|Ix, m|Ix) = - make_png for i j k. float_to_8bit img[i, j, k] + make_png for i:n j:m k:(Fin 3). float_to_8bit img[i, j, k] '## API entry point @@ -143,4 +140,4 @@ def imshow(img:n=>m=>(Fin 3)=>Float) -> Html given (n|Ix, m|Ix) = img_to_html $ img_to_png img def imseqshow(imgs:t=>n=>m=>(Fin 3)=>Float) -> Html given (t|Ix, n|Ix, m|Ix) = - img_to_html $ pngs_to_gif 50 $ map img_to_png imgs + imgs | each(img_to_png) | pngs_to_gif(50) | img_to_html diff --git a/lib/prelude.dx b/lib/prelude.dx index c2baec9e6..8c62e15b5 100644 --- a/lib/prelude.dx +++ b/lib/prelude.dx @@ -32,7 +32,8 @@ interface Data(a:Type) '### Casting -def internal_cast(x:from) -> to given (from, to) = +@inline +def internal_cast(x:from) -> to given (from:Type, to:Type) = %cast(to, x) def unsafe_coerce(x:from) -> to given (from|Data, to|Data) = %unsafeCoerce(to, x) @@ -59,6 +60,7 @@ Nat = %Nat() NatRep = Word32 def nat_to_rep(x : Nat) -> NatRep = %projNewtype(x) +@inline def rep_to_nat(x : NatRep) -> Nat = %NatCon(x) def n_to_w8(x: Nat) -> Word8 = nat_to_rep x | internal_cast @@ -72,6 +74,7 @@ def n_to_f(x: Nat) -> Float = nat_to_rep x | internal_cast def w8_to_n(x : Word8) -> Nat = internal_cast x | rep_to_nat def w32_to_n(x : Word32) -> Nat = internal_cast x | rep_to_nat +@inline def w64_to_n(x : Word64) -> Nat = internal_cast x | rep_to_nat def i32_to_n(x : Int32) -> Nat = internal_cast x | rep_to_nat def i64_to_n(x : Int64) -> Nat = internal_cast x | rep_to_nat @@ -266,7 +269,7 @@ instance Mul(()) '#### Integral Integer-like things. -interface Integral(a) +interface Integral(a:Type) idiv : (a,a)->a rem : (a,a)->a @@ -298,7 +301,7 @@ instance Integral(Nat) Rational-like things. Includes floating point and two field rational representations. -interface Fractional(a) +interface Fractional(a:Type) divide : (a, a) -> a instance Fractional(Float64) @@ -314,51 +317,47 @@ interface Ix(n|Data) ordinal : (n) -> Nat unsafe_from_ordinal : (Nat) -> n -def size(n|Ix) -> Nat = size'(n=n) +def size(n:Type|Ix) -> Nat = size'(n=n) def Fin(n:Nat) -> Type = %Fin(n) --- version of subtraction on Nats that clamps at zero +# version of subtraction on Nats that clamps at zero def (-|)(x: Nat, y:Nat) -> Nat = x' = nat_to_rep x y' = nat_to_rep y requires_clamp = %ilt(x', y') - rep_to_nat %select(requires_clamp, 0, (%isub(x', y'))) + rep_to_nat %select(requires_clamp, 0::NatRep, (%isub(x', y'))) def unsafe_nat_diff(x:Nat, y:Nat) -> Nat = x' = nat_to_rep x y' = nat_to_rep y rep_to_nat %isub(x', y') --- `(i..)` parses as `RangeFrom _ i` --- TODO: need to a way to indicate constructor as private -struct RangeFrom(q:Type, i:q) = val : Nat +# TODO: need to a way to indicate constructor as private +struct RangeFrom(i:q) given (q:Type) = val : Nat --- `(i<..)` parses as `RangeFromExc _ i` -struct RangeFromExc(q:Type, i:q) = val : Nat +struct RangeFromExc(i:q) given (q:Type) = val : Nat --- `(..i)` parses as `RangeTo _ i` -struct RangeTo(q:Type, i:q) = val : Nat +struct RangeTo(i:q) given (q:Type) = val : Nat --- `(.. n=>Nat = for i. ordinal i +def iota(n:Type|Ix) -> n=>Nat = for i. ordinal i '## Arithmetic instances for table types @@ -378,28 +377,28 @@ instance Add(n=>a) given (a|Add, n|Ix) instance Sub(n=>a) given (a|Sub, n|Ix) def (-)(xs, ys) = for i. xs[i] - ys[i] -instance Add((i:n) => (i..) => a) given (a|Add, n|Ix) -- Upper triangular tables +instance Add((i:n) => RangeFrom i => a) given (a|Add, n|Ix) # Upper triangular tables def (+)(xs, ys) = for i. xs[i] + ys[i] zero = for _. zero -instance Sub((i:n) => (i..) => a) given (a|Sub, n|Ix) -- Upper triangular tables +instance Sub((i:n) => RangeFrom i => a) given (a|Sub, n|Ix) # Upper triangular tables def (-)(xs, ys) = for i. xs[i] - ys[i] -instance Add((i:n) => (..i) => a) given (a|Add, n|Ix) -- Lower triangular tables +instance Add((i:n) => RangeTo i => a) given (a|Add, n|Ix) # Lower triangular tables def (+)(xs, ys) = for i. xs[i] + ys[i] zero = for _. zero -instance Sub((i:n) => (..i) => a) given (a|Sub, n|Ix) -- Lower triangular tables +instance Sub((i:n) => RangeTo i => a) given (a|Sub, n|Ix) # Lower triangular tables def (-)(xs, ys) = for i. xs[i] - ys[i] -instance Add((i:n) => (.. a) given (a|Add, n|Ix) +instance Add((i:n) => RangeToExc i => a) given (a|Add, n|Ix) def (+)(xs, ys) = for i. xs[i] + ys[i] zero = for _. zero -instance Sub((i:n) => (.. a) given (a|Sub, n|Ix) +instance Sub((i:n) => RangeToExc i => a) given (a|Sub, n|Ix) def (-)(xs, ys) = for i. xs[i] - ys[i] -instance Add((i:n) => (i<..) => a) given (a|Add, n|Ix) +instance Add((i:n) => RangeFromExc i => a) given (a|Add, n|Ix) def (+)(xs, ys) = for i. xs[i] + ys[i] zero = for _. zero -instance Sub((i:n) => (i<..) => a) given (a|Sub, n|Ix) +instance Sub((i:n) => RangeFromExc i => a) given (a|Sub, n|Ix) def (-)(xs, ys) = for i. xs[i] - ys[i] instance Mul(n=>a) given (a|Mul, n|Ix) @@ -408,10 +407,10 @@ instance Mul(n=>a) given (a|Mul, n|Ix) '## Basic polymorphic functions and types -def fst(pair:(a, b)) -> a given (a, b) = pair.0 -def snd(pair:(a, b)) -> b given (a, b) = pair.1 +def fst(pair:(a, b)) -> a given (a:Type, b:Type) = pair.0 +def snd(pair:(a, b)) -> b given (a:Type, b:Type) = pair.1 -def swap(pair:(a, b)) -> (b, a) given (a, b) = +def swap(pair:(a, b)) -> (b, a) given (a:Type, b:Type) = (x, y) = pair (y, x) @@ -443,7 +442,7 @@ instance Ix((a, b, c)) given (a|Ix, b|Ix, c|Ix) (i, j, k) = tup ordinal((i,(j,k))) def unsafe_from_ordinal(o) = - (i, (j, k)) = unsafe_from_ordinal o + (i, (j, k)) = unsafe_from_ordinal(n=(a,(b,c))::Type, o) (i, j, k) instance Ix((a, b, c, d)) given (a|Ix, b|Ix, c|Ix, d|Ix) @@ -452,7 +451,7 @@ instance Ix((a, b, c, d)) given (a|Ix, b|Ix, c|Ix, d|Ix) (i, j, k, m) = tup ordinal((i,(j,(k,m)))) def unsafe_from_ordinal(o) = - (i, (j, (k, m))) = unsafe_from_ordinal o + (i, (j, (k, m))) = unsafe_from_ordinal(n=(a,(b,(c,d)))::Type, o) (i, j, k, m) '## Vector spaces @@ -475,16 +474,16 @@ instance VSpace((a, b)) given (a|VSpace, b|VSpace) (x, y) = pair (s .* x, s .* y) -instance VSpace((i:n) => (..i) => a) given (n|Ix, a|VSpace) +instance VSpace((i:n) => RangeTo i => a) given (n|Ix, a|VSpace) def (.*)(s, xs) = for i. s .* xs[i] -instance VSpace((i:n) => (i..) => a) given (n|Ix, a|VSpace) +instance VSpace((i:n) => RangeFrom i => a) given (n|Ix, a|VSpace) def (.*)(s, xs) = for i. s .* xs[i] -instance VSpace((i:n) => (.. a) given (n|Ix, a|VSpace) +instance VSpace((i:n) => RangeToExc i => a) given (n|Ix, a|VSpace) def (.*)(s, xs) = for i. s .* xs[i] -instance VSpace((i:n) => (i<..) => a) given (n|Ix, a|VSpace) +instance VSpace((i:n) => RangeFromExc i => a) given (n|Ix, a|VSpace) def (.*)(s, xs) = for i. s .* xs[i] instance VSpace(()) @@ -517,9 +516,9 @@ def not(x:Bool) -> Bool = '## More Boolean operations TODO: move these with the others? --- Can't use `%select` because it lowers to `ISelect`, which requires --- `a` to be a `BaseTy`. -def select(p:Bool, x:a, y:a) -> a given (a) = +# Can't use `%select` because it lowers to `ISelect`, which requires +# `a` to be a `BaseTy`. +def select(p:Bool, x:a, y:a) -> a given (a:Type) = case p of True -> x False -> y @@ -547,14 +546,14 @@ data Maybe(a:Type) = Nothing Just(a) -def is_nothing(x:Maybe a) -> Bool given (a) = +def is_nothing(x:Maybe a) -> Bool given (a:Type) = case x of Nothing -> True Just(_) -> False -def is_just(x:Maybe a) -> Bool given (a) = not $ is_nothing x +def is_just(x:Maybe a) -> Bool given (a:Type) = not $ is_nothing x -def maybe(d:b, f:(a)->b, x:Maybe a) -> b given (a, b) = +def maybe(d:b, f:(a)->b, x:Maybe a) -> b given (a:Type, b:Type) = case x of Nothing -> d Just(x') -> f x' @@ -571,14 +570,14 @@ instance Ix(Either(a, b)) given (a|Ix, b|Ix) def unsafe_from_ordinal(o) = as = nat_to_rep $ size a o' = nat_to_rep o - -- TODO: Reshuffle the prelude to be able to use (<) here + # TODO: Reshuffle the prelude to be able to use (<) here case w8_to_b $ %ilt(o', as) of True -> Left $ unsafe_from_ordinal(n=a, o) - -- TODO: Reshuffle the prelude to be able to use `diff_nat` here + # TODO: Reshuffle the prelude to be able to use `diff_nat` here False -> Right $ unsafe_from_ordinal(n=b, rep_to_nat (%isub(o', as))) '## Subtraction on Nats --- TODO: think more about the right API here +# TODO: think more about the right API here def unsafe_i_to_n(x:Int) -> Nat = rep_to_nat $ internal_cast x @@ -593,7 +592,6 @@ def i_to_n(x:Int) -> Maybe Nat = '### Monoid A [monoid](https://en.wikipedia.org/wiki/Monoid) is a things that have an associative binary operator and an identity element. -This is a very useful and general calls of things. It includes: - Addition and Multiplication of Numbers - Boolean Logic @@ -626,55 +624,55 @@ named-instance MulMonoid(a|Mul) -> Monoid(a) '## Effects -def Ref(r:Heap, a|Data) -> Type = %Ref(r, a) -def get(ref:Ref h s) -> {State h} s given (h, s) = %get(ref) -def (:=)(ref:Ref h s, x:s) -> {State h} () given (h, s) = %put(ref, x) +def Ref(r:Heap, a:Type|Data) -> Type = %Ref(r, a) +def get(ref:Ref h s) -> {State h} s given (h:Heap, s|Data) = %get(ref) +def (:=)(ref:Ref h s, x:s) -> {State h} () given (h:Heap, s|Data) = %put(ref, x) -def ask(ref:Ref h r) -> {Read h} r given (h, r) = %ask(ref) +def ask(ref:Ref h r) -> {Read h} r given (h:Heap, r|Data) = %ask(ref) data AccumMonoidData(h:Heap, w:Type) = UnsafeMkAccumMonoidData(b:Type, Monoid b) -interface AccumMonoid(h:Heap, w) +interface AccumMonoid(h:Heap, w:Type) getAccumMonoidData : AccumMonoidData(h, w) -instance AccumMonoid(h, n=>w) given (n|Ix, h, w) (am:AccumMonoid(h, w)) +instance AccumMonoid(h, n=>w) given (n|Ix, h:Heap, w:Type) (am:AccumMonoid(h, w)) getAccumMonoidData = UnsafeMkAccumMonoidData(b, bm) = %applyMethod0(am) UnsafeMkAccumMonoidData(b, bm) def (+=)(ref:Ref h w, x:w) -> {Accum h} () - given (h, w) (am:AccumMonoid(h, w)) = + given (h:Heap, w|Data) (am:AccumMonoid(h, w)) = UnsafeMkAccumMonoidData(b, bm) = %applyMethod0(am) empty = %applyMethod0(bm) %mextend(ref, empty, \x:b y:b. %applyMethod1(bm, x, y), x) -def (!)(ref: Ref h (n=>a), i:n) -> Ref h a given (n|Ix, a|Data, h) = %indexRef(ref, i) -def fst_ref(ref: Ref h (a,b)) -> Ref h a given (b, a|Data, h) = ref.0 -def snd_ref(ref: Ref h (a,b)) -> Ref h b given (a, b|Data, h) = ref.1 +def (!)(ref: Ref h (n=>a), i:n) -> Ref h a given (n|Ix, a|Data, h:Heap) = %indexRef(ref, i) +def fst_ref(ref: Ref h (a,b)) -> Ref h a given (b|Data, a|Data, h:Heap) = ref.0 +def snd_ref(ref: Ref h (a,b)) -> Ref h b given (a|Data, b|Data, h:Heap) = ref.1 def run_reader( init:r, - action:(given (h), Ref h r) -> {Read h|eff} a - ) -> {|eff} a given (r|Data, a, eff) = + action:(given (h:Heap), Ref h r) -> {Read h|eff} a + ) -> {|eff} a given (r|Data, a:Type, eff:Effects) = def explicitAction(h':Heap, ref:Ref h' r) -> {Read h'|eff} a = action ref %runReader(init, explicitAction) def with_reader( init:r, - action: (given (h), Ref(h,r)) -> {Read h|eff} a - ) -> {|eff} a given (r|Data, a, eff) = + action: (given (h:Heap), Ref(h,r)) -> {Read h|eff} a + ) -> {|eff} a given (r|Data, a:Type, eff:Effects) = run_reader(init, action) def MonoidLifter(b:Type, w:Type) -> Type = - (given (h) (AccumMonoid(h, b))) ->> AccumMonoid(h, w) + (given (h:Heap) (AccumMonoid(h, b))) ->> AccumMonoid(h, w) -named-instance mk_accum_monoid (given (h, w), d:AccumMonoidData(h, w)) -> AccumMonoid(h, w) +named-instance mk_accum_monoid (given (h:Heap, w:Type), d:AccumMonoidData(h, w)) -> AccumMonoid(h, w) getAccumMonoidData = d def run_accum( bm:Monoid b, - action: (given (h) (AccumMonoid(h, b)), Ref h w) -> {Accum h|eff} a - ) -> {|eff} (a, w) given (a, b, w|Data, eff) (MonoidLifter(b,w)) = + action: (given (h:Heap) (AccumMonoid(h, b)), Ref h w) -> {Accum h|eff} a + ) -> {|eff} (a, w) given (a:Type, b:Type, w|Data, eff:Effects) (MonoidLifter(b,w)) = empty = %applyMethod0(bm) def explicitAction(h':Heap, ref:Ref h' w) -> {Accum h'|eff} a = accumMonoidData : AccumMonoidData h' b = UnsafeMkAccumMonoidData b bm @@ -684,32 +682,32 @@ def run_accum( def yield_accum( m:Monoid b, - action: (given (h) (AccumMonoid(h, b)), Ref h w) -> {Accum h|eff} a - ) -> {|eff} w given (a, b, w|Data, eff) (MonoidLifter b w) = + action: (given (h:Heap) (AccumMonoid(h, b)), Ref h w) -> {Accum h|eff} a + ) -> {|eff} w given (a:Type, b:Type, w|Data, eff:Effects) (MonoidLifter b w) = snd $ run_accum(m, action) def run_state( init:s, - action: (given (h), Ref h s) -> {State h |eff} a - ) -> {|eff} (a,s) given (a, s|Data, eff) = + action: (given (h:Heap), Ref h s) -> {State h |eff} a + ) -> {|eff} (a,s) given (a:Type, s|Data, eff:Effects) = def explicitAction(h':Heap, ref:Ref h' s) -> {State h'|eff} a = action ref %runState(init, explicitAction) def with_state( init:s, - action: (given (h), Ref h s) -> {State h |eff} a - ) -> {|eff} a given (a, s|Data, eff) = + action: (given (h:Heap), Ref h s) -> {State h |eff} a + ) -> {|eff} a given (a:Type, s|Data, eff:Effects) = fst $ run_state(init, action) def yield_state( init:s, - action: (given (h), Ref h s) -> {State h |eff} a - ) -> {|eff} s given (a, s|Data, eff) = + action: (given (h:Heap), Ref h s) -> {State h |eff} a + ) -> {|eff} s given (a:Type, s|Data, eff:Effects) = snd $ run_state(init, action) def unsafe_io( f:()->{IO|eff} a - ) -> {|eff} a given (a, eff) = + ) -> {|eff} a given (a:Type, eff:Effects) = f' : (() -> {IO|eff} a) = \. f() %runIO(f') @@ -848,11 +846,11 @@ instance Ord(Nat) def (>)(x, y) = nat_to_rep x > nat_to_rep y def (<)(x, y) = nat_to_rep x < nat_to_rep y --- TODO: we want Eq and Ord for all index sets, not just `Fin n` -instance Eq(Fin n) given (n) +# TODO: we want Eq and Ord for all index sets, not just `Fin n` +instance Eq(Fin n) given (n:Nat) def (==)(x, y) = ordinal x == ordinal y -instance Ord(Fin n) given (n) +instance Ord(Fin n) given (n:Nat) def (>)(x, y) = ordinal x > ordinal y def (<)(x, y) = ordinal x < ordinal y @@ -870,7 +868,7 @@ instance Ix(Maybe a) given (a|Ix) Nothing -> size a def unsafe_from_ordinal(o) = case o == size a of - False -> Just $ unsafe_from_ordinal o + False -> Just $ unsafe_from_ordinal(n=a, o) True -> Nothing interface NonEmpty(n|Ix) @@ -888,10 +886,10 @@ instance NonEmpty((a,b)) given (a|NonEmpty, b|NonEmpty) instance NonEmpty(Either(a,b)) given (a|NonEmpty, b|Ix) first_ix = unsafe_from_ordinal 0 --- The below instance is valid, but causes "multiple candidate dictionaries" --- errors if both Left and Right are NonEmpty. --- instance NonEmpty (a|b) given {a b} [Ix a, NonEmpty b] --- first_ix = unsafe_from_ordinal _ 0 +# The below instance is valid, but causes 'multiple candidate dictionaries' +# errors if both Left and Right are NonEmpty. +# instance NonEmpty (a|b) given {a b} [Ix a, NonEmpty b] +# first_ix = unsafe_from_ordinal _ 0 instance NonEmpty(Maybe a) given (a|Ix) first_ix = unsafe_from_ordinal 0 @@ -916,13 +914,13 @@ def left_fence(p:Post n) -> Maybe n given (n|Ix) = ix = ordinal p if ix == 0 then Nothing - else Just $ unsafe_from_ordinal $ ix -| 1 + else Just $ unsafe_from_ordinal(n=n, ix -| 1) def right_fence(p:Post n) -> Maybe n given (n|Ix) = ix = ordinal p if ix == size n then Nothing - else Just $ unsafe_from_ordinal ix + else Just $ unsafe_from_ordinal(n=n, ix) def last_ix() ->> n given (n|NonEmpty) = unsafe_from_ordinal(unsafe_i_to_n(n_to_i(size n) - 1)) @@ -930,19 +928,6 @@ def last_ix() ->> n given (n|NonEmpty) = instance NonEmpty(Post n) given (n|Ix) first_ix = unsafe_from_ordinal(n=Post n, 0) -def scan( - init:a, - body:(n, a)->(a,b) - ) -> (a, n=>b) given (a|Data, b, n|Ix) = - swap $ run_state(init) \s. for i. - c = get s - (c', y) = body(i, c) - s := c' - y - -def fold(init:a, body:(n,a)->a) -> a given (n|Ix, a|Data) = - fst $ scan init \i x. (body(i, x), ()) - def compare(x:a, y:a) -> Ordering given (a|Ord) = if x < y then LT @@ -961,42 +946,32 @@ instance Monoid(Ordering) instance Eq(n=>a) given (n|Ix, a|Eq) def (==)(xs, ys) = yield_accum AndMonoid \ref. - for i. ref += xs[i] == ys[i] - -instance Ord(n=>a) given (n|Ix, a|Ord) - def (>)(xs, ys) = - f: Ordering = - fold EQ $ \i c. c <> compare(xs[i], ys[i]) - f == GT - def (<)(xs, ys) = - f: Ordering = - fold EQ $ \i c. c <> compare(xs[i], ys[i]) - f == LT + for i:n. ref += xs[i] == ys[i] '## Subset class -interface Subset(subset, superset) +interface Subset(subset:Type, superset:Type) inject' : (subset) -> superset project' : (superset) -> Maybe subset unsafe_project' : (superset) -> subset --- wrappers with more helpful implicit arg names -def inject(x:from) -> to given (to, from) (Subset(from, to)) = inject'(x) -def project(x:from) -> Maybe to given (to, from) (Subset(to, from)) = project'(x) -def unsafe_project(x:from) -> to given (to, from) (Subset(to, from)) = unsafe_project'(x) +# wrappers with more helpful implicit arg names +def inject(x:from) -> to given (to:Type, from:Type) (Subset(from, to)) = inject'(x) +def project(x:from) -> Maybe to given (to:Type, from:Type) (Subset(to, from)) = project'(x) +def unsafe_project(x:from) -> to given (to:Type, from:Type) (Subset(to, from)) = unsafe_project'(x) -instance Subset(a, c) given (a, b, c) (Subset(a, b), Subset(b, c)) +instance Subset(a, c) given (a:Type, b:Type, c:Type) (Subset(a, b), Subset(b, c)) def inject'(x) = inject $ inject(to=b, x) def project'(x) = case project(to=b, x) of Nothing -> Nothing Just(y)-> project y def unsafe_project'(x) = unsafe_project $ unsafe_project(to=b, x) -def unsafe_project_rangefrom(j:q) -> RangeFrom(q, i) given (q|Ix, i:q) = +def unsafe_project_rangefrom(j:q) -> RangeFrom(i) given (q|Ix, i:q) = RangeFrom unsafe_nat_diff(ordinal j, ordinal i) -instance Subset(RangeFrom(q, i), q) given (q|Ix, i:q) +instance Subset(RangeFrom(i), q) given (q|Ix, i:q) def inject'(j) = unsafe_from_ordinal $ j.val + ordinal i def project'(j) = @@ -1007,7 +982,7 @@ instance Subset(RangeFrom(q, i), q) given (q|Ix, i:q) else Just $ RangeFrom $ unsafe_nat_diff(j', i') def unsafe_project'(j) = RangeFrom unsafe_nat_diff(ordinal j, ordinal i) -instance Subset(RangeFromExc(q, i), q) given (q|Ix, i:q) +instance Subset(RangeFromExc(i), q) given (q|Ix, i:q) def inject'(j) = unsafe_from_ordinal $ j.val + ordinal i + 1 def project'(j) = j' = ordinal j @@ -1018,7 +993,7 @@ instance Subset(RangeFromExc(q, i), q) given (q|Ix, i:q) def unsafe_project'(j) = RangeFromExc unsafe_nat_diff(ordinal j, ordinal i + 1) -instance Subset(RangeTo(q, i), q) given (q|Ix, i:q) +instance Subset(RangeTo(i), q) given (q|Ix, i:q) def inject'(j) = unsafe_from_ordinal j.val def project'(j) = j' = ordinal j @@ -1028,7 +1003,7 @@ instance Subset(RangeTo(q, i), q) given (q|Ix, i:q) else Just $ RangeTo j' def unsafe_project'(j) = RangeTo (ordinal j) -instance Subset(RangeToExc(q, i), q) given (q|Ix, i:q) +instance Subset(RangeToExc(i), q) given (q|Ix, i:q) def inject'(j) = unsafe_from_ordinal j.val def project'(j) = j' = ordinal j @@ -1038,7 +1013,7 @@ instance Subset(RangeToExc(q, i), q) given (q|Ix, i:q) else Just $ RangeToExc j' def unsafe_project'(j) = RangeToExc (ordinal j) -instance Subset(RangeToExc(q, i), RangeTo(q, i)) given (q|Ix, i:q) +instance Subset(RangeToExc(i), RangeTo(i)) given (q|Ix, i:q) def inject'(j) = unsafe_from_ordinal j.val def project'(j) = j' = ordinal j @@ -1078,14 +1053,14 @@ interface Floating(a:Type) def lbeta(x:a, y:a) -> a given (a|Sub|Floating) = lgamma x + lgamma y - lgamma (x + y) --- Todo: better numerics for very large and small values. --- Using %exp here to avoid circular definition problems. +# Todo: better numerics for very large and small values. +# Using %exp here to avoid circular definition problems. def float32_sinh(x:Float32) -> Float32 = %fdiv(%fsub(%exp(x), %exp(%fsub(0.0,x))), 2.0) def float32_cosh(x:Float32) -> Float32 = %fdiv(%fadd(%exp(x), %exp(%fsub(0.0,x))), 2.0) def float32_tanh(x:Float32) -> Float32 = %fdiv(%fsub(%exp(x), %exp(%fsub(0.0,x))) ,%fadd(%exp(x), %exp(%fsub(0.0,x)))) --- Todo: unify this with float32 functions. +# Todo: unify this with float32 functions. def float64_sinh(x:Float64) -> Float64 = %fdiv(%fsub(%exp(x), %exp(%fsub(f_to_f64 0.0, x))), f_to_f64 2.0) def float64_cosh(x:Float64) -> Float64 = %fdiv(%fadd(%exp(x), %exp(%fsub(f_to_f64 0.0, x))), f_to_f64 2.0) def float64_tanh(x:Float64) -> Float64 = %fdiv(%fsub(%exp(x), %exp(%fsub(f_to_f64 0.0, x))) @@ -1140,7 +1115,7 @@ instance Floating(Float32) struct Ptr(a:Type) = val : RawPtr -def cast_ptr(ptr: Ptr a) -> Ptr b given (a, b) = Ptr(ptr.val) +def cast_ptr(ptr: Ptr from) -> Ptr to given (from:Type, to:Type) = Ptr(ptr.val) interface Storable(a|Data) store : (Ptr a, a) -> {IO} () @@ -1168,43 +1143,44 @@ instance Storable(Float32) def storage_size() = 4 instance Storable(Nat) - def store(ptr, x) = store(Ptr(ptr.val), nat_to_rep x) + def store(ptr, x) = store(cast_ptr(ptr, to=%Word32()), nat_to_rep x) def load(ptr) = rep_to_nat $ load(Ptr(ptr.val)) def storage_size() = storage_size(a=NatRep) -instance Storable(Ptr a) given (a) +instance Storable(Ptr a) given (a:Type) def store(ptr, x) = %ptrStore(internal_cast(to=%PtrPtr(), ptr.val), x.val) def load(ptr) = Ptr(%ptrLoad(internal_cast(to=%PtrPtr(), ptr))) - def storage_size() = 8 -- TODO: something more portable? + def storage_size() = 8 # TODO: something more portable? --- TODO: Storable instances for other types +# TODO: Storable instances for other types def malloc(n:Nat) -> {IO} (Ptr a) given (a|Storable) = numBytes = storage_size(a=a) * n Ptr(%alloc(nat_to_rep numBytes)) -def free(ptr:Ptr a) -> {IO} () given (a) = %free(ptr.val) +def free(ptr:Ptr a) -> {IO} () given (a:Type) = %free(ptr.val) def (+>>)(ptr:Ptr a, i:Nat) -> Ptr a given (a|Storable) = i' = nat_to_rep $ i * storage_size(a=a) Ptr(%ptrOffset(ptr.val, i')) --- TODO: consider making a Storable instance for tables instead +# TODO: consider making a Storable instance for tables instead def store_table(ptr: Ptr a, tab:n=>a) -> {IO} () given (a|Storable, n|Ix) = - for_ i. store(ptr +>> ordinal i, tab[i]) + for_ i:n. store(ptr +>> ordinal i, tab[i]) def memcpy(dest:Ptr a, src:Ptr a, n:Nat) -> {IO} () given (a|Storable) = for_ i:(Fin n). i' = ordinal i store(dest +>> i', load $ src +>> i') --- TODO: generalize these brackets to allow other effects --- TODO: make sure that freeing happens even if there are run-time errors +# TODO: generalize these brackets to allow other effects +# TODO: make sure that freeing happens even if there are run-time errors def with_alloc( + a|Storable, n:Nat, action: (Ptr a) -> {IO} b - ) -> {IO} b given (a|Storable, b) = - ptr = malloc n + ) -> {IO} b given (b:Type) = + ptr = malloc(a=a, n) result = action ptr free ptr result @@ -1212,9 +1188,9 @@ def with_alloc( def with_table_ptr( xs:n=>a, action: (Ptr a) -> {IO} b - ) -> {IO} b given (a|Storable, b, n|Ix) = - ptr <- with_alloc(size n) - for i. store(ptr +>> ordinal i, xs[i]) + ) -> {IO} b given (a|Storable, b:Type, n|Ix) = + ptr <- with_alloc(a, size n) + for i:n. store(ptr +>> ordinal i, xs[i]) action ptr def table_from_ptr(ptr:Ptr a) -> {IO} n=>a given (a|Storable, n|Ix) = @@ -1224,22 +1200,31 @@ def table_from_ptr(ptr:Ptr a) -> {IO} n=>a given (a|Storable, n|Ix) = pi : Float = 3.141592653589793 -def id(x:a) -> a given (a) = x -def dup(x:a) -> (a, a) given (a) = (x, x) -def map(f:(a)->{|eff} b, xs: n=>a) -> {|eff} (n=>b) given (a, b, n|Ix, eff) = - for i. f xs[i] --- map, flipped so that the function goes last -def each(xs: n=>a, f:(a)->{|eff} b) -> {|eff} (n=>b) given (a, b, n|Ix, eff) = +def id(x:a) -> a given (a:Type) = x +def dup(x:a) -> (a, a) given (a:Type) = (x, x) +# map, flipped so that the function goes last +def each(xs: n=>a, f:(a)->{|eff} b) -> {|eff} (n=>b) given (a:Type, b:Type, n|Ix, eff:Effects) = for i. f xs[i] -def zip(xs:n=>a, ys:n=>b) -> (n=>(a,b)) given (a, b, n|Ix) = for i. (xs[i], ys[i]) -def unzip(xys:n=>(a,b)) -> (n=>a , n=>b) given (a, b, n|Ix)= (each xys fst, each xys snd) -def fanout(x:a) -> n=>a given (n|Ix, a) = for i. x +def zip(xs:n=>a, ys:n=>b) -> (n=>(a,b)) given (a:Type, b:Type, n|Ix) = for i. (xs[i], ys[i]) +def unzip(xys:n=>(a,b)) -> (n=>a , n=>b) given (a:Type, b:Type, n|Ix) = + (each xys \xy. fst(xy), each xys \xy. snd(xy)) +def fanout(x:a) -> n=>a given (n|Ix, a:Type) = for i. x def sq(x:a) -> a given (a|Mul) = x * x -def abs(x:a) -> a given (a|Sub|Ord) = select(x > zero, x, zero - x) +def abs(x:a) -> a given (a|Sub|Ord) = select(x > zero, x, (zero::a) - x) def mod(x:a, y:a) -> a given (a|Add|Integral) = rem(y + rem(x, y), y) -def (>>>)(f:(a) -> b, g:(b) -> c) -> (a) -> c given (a, b, c) = \x. g(f(x)) -def (<<<)(f:(b) -> c, g:(a) -> b) -> (a) -> c given (a, b, c) = \x. f(g(x)) +def (>>>)(f:(a) -> b, g:(b) -> c) -> (a) -> c given (a:Type, b:Type, c:Type) = \x. g(f(x)) +def (<<<)(f:(b) -> c, g:(a) -> b) -> (a) -> c given (a:Type, b:Type, c:Type) = \x. f(g(x)) + +def flatten2D(mat:n=>m=>a) -> (n,m)=>a given (n|Ix, m|Ix, a:Type) = + for pair. + (i, j) = pair + mat[i,j] + +def flatten3D(array:l=>n=>m=>a) -> (l,n,m)=>a given (l|Ix, n|Ix, m|Ix, a:Type) = + for triple. + (i, j, k) = triple + array[i,j,k] '## Table Operations @@ -1267,23 +1252,34 @@ instance Floating(n=>a) given (a|Floating, n|Ix) '### Reductions --- `combine` should be a commutative and associative, and form a --- commutative monoid with `identity` -def reduce(identity:a, combine:(a,a)->a, xs:n=>a) -> a given (a|Data, n|Ix) = - -- TODO: implement with the accumulator effect - fold identity \i c. combine(c, xs[i]) +def scan( + init:c, + xs: n=>a, + body:(n, a, c)->(b, c) + ) -> (n=>b, c) given (a:Type, b:Type, c|Data, n|Ix) = + run_state(init) \ref. for i:n. + carry = get ref + (y, carry') = body(i, xs[i], carry) + ref := carry' + y + +def fold(init:c, xs:n=>a, body:(n, a, c)-> c) -> c given (a:Type, n|Ix, c|Data) = + snd $ scan(init, xs) \i x carry. ((), body(i, x, carry)) + +# `combine` should be a commutative and associative, and form a +# commutative monoid with `identity` +def reduce(xs:n=>a, identity:a, combine:(a,a)->a) -> a given (a|Data, n|Ix) = + # TODO: implement with the accumulator effect + fold(identity, xs) \i x c. combine(c, x) --- TODO: call this `scan` and call the current `scan` something else -def scan'(init:a, body:(n,a)->a) -> n=>a given (a|Data, n|Ix) = - snd $ scan init \i x. dup(body(i, x)) def fsum(xs:n=>Float) -> Float given (n|Ix) = - yield_accum(AddMonoid Float) \ref. for i. ref += xs[i] -def sum(xs:n=>v) -> v given (n|Ix, v|Add) = reduce(zero, (+), xs) -def prod(xs:n=>v) -> v given (n|Ix, v|Mul) = reduce(one , (*), xs) + yield_accum(AddMonoid Float) \ref. each xs \x. ref += x +def sum(xs:n=>v) -> v given (n|Ix, v|Add) = reduce(xs, zero, (+)) +def prod(xs:n=>v) -> v given (n|Ix, v|Mul) = reduce(xs, one , (*)) def mean(xs:n=>v) -> v given (n|Ix, v|VSpace) = sum xs / n_to_f (size n) def std(xs:n=>v) -> v given (n|Ix, v|Mul|Sub|VSpace|Floating) = sqrt $ mean (each xs sq) - sq (mean xs) -def any(xs:n=>Bool) -> Bool given (n|Ix) = reduce(False, (||), xs) -def all(xs:n=>Bool) -> Bool given (n|Ix) = reduce(True , (&&), xs) +def any(xs:n=>Bool) -> Bool given (n|Ix) = reduce(xs, False, (||)) +def all(xs:n=>Bool) -> Bool given (n|Ix) = reduce(xs, True , (&&)) '### apply_n @@ -1296,15 +1292,15 @@ TODO: Move this to be with reductions? It's a kind of `scan`. def cumsum(xs: n=>a) -> n=>a given (n|Ix, a|Add) = - total <- with_state zero - for i. + total <- with_state (zero::a) + for i:n. newTotal = get total + xs[i] total := newTotal newTotal def cumsum_low(xs: n=>a) -> n=>a given (n|Ix, a|Add) = - total <- with_state zero - for i. + total <- with_state (zero::a) + for i:n. oldTotal = get total total := oldTotal + xs[i] oldTotal @@ -1313,27 +1309,27 @@ def cumsum_low(xs: n=>a) -> n=>a given (n|Ix, a|Add) = '### AD operations --- TODO: add vector space constraints -def linearize(f:(a)->b, x:a) -> (b, (a)->b) given (a, b) = - %linearize(\x. f x, x) +# TODO: add vector space constraints +def linearize(f:(a)->b, x:a) -> (b, (a)->b) given (a:Type, b:Type) = + %linearize(\x:a. f x, x) -def jvp(f:(a)->b, x:a, t:a) -> b given (a, b) = (snd $ linearize(f, x))(t) -def transpose_linear(f:(a)->b) -> (b)->a given (a, b) = \ct. - %linearTranspose(\x. f x, ct) +def jvp(f:(a)->b, x:a, t:a) -> b given (a:Type, b:Type) = (snd $ linearize(f, x))(t) +def transpose_linear(f:(a)->b) -> (b)->a given (a:Type, b:Type) = \ct. + %linearTranspose(\x:a. f x, ct) -def vjp(f:(a)->b, x:a) -> (b, (b)->a) given (a, b) = +def vjp(f:(a)->b, x:a) -> (b, (b)->a) given (a:Type, b:Type) = (y, df) = linearize(f, x) (y, transpose_linear df) -def grad(f:(a)->Float, x:a) -> a given (a) = (snd vjp(f, x))(1.0) +def grad(f:(a)->Float, x:a) -> a given (a:Type) = (snd vjp(f, x))(1.0) def deriv(f:(Float)->Float, x:Float) -> Float = jvp(f, x, 1.0) def deriv_rev(f:(Float)->Float, x:Float) -> Float = (snd vjp(f, x))(1.0) --- XXX: Watch out when editing this data type! We depend on its structure --- deep inside the compiler (mostly in linearization and during rule registration). -data SymbolicTangent(a) = +# XXX: Watch out when editing this data type! We depend on its structure +# deep inside the compiler (mostly in linearization and during rule registration). +data SymbolicTangent(a:Type) = ZeroTangent SomeTangent(a) @@ -1345,15 +1341,15 @@ def someTangent(x:SymbolicTangent a) -> a given (a|VSpace) = '### Approximate Equality TODO: move this outside the AD section to be with equality? -interface HasAllClose(a) +interface HasAllClose(a:Type) allclose : (a, a, a, a) -> Bool -interface HasDefaultTolerance(a) +interface HasDefaultTolerance(a:Type) default_atol : a default_rtol : a def (~~)(x:a, y:a) -> Bool given (a|HasAllClose|HasDefaultTolerance) = - allclose(default_atol, default_rtol, x, y) + allclose(a=a, default_atol, default_rtol, x, y) instance HasAllClose(Float32) def allclose(atol, rtol, x, y) = abs (x - y) <= (atol + rtol * abs y) @@ -1402,7 +1398,7 @@ def check_deriv(f:(Float)->Float, x:Float) -> Bool = '## Length-erased lists -data List(a)= +data List(a:Type) = AsList(n:Nat, elements:(Fin n => a)) instance Eq(List a) given (a|Eq) @@ -1414,15 +1410,15 @@ instance Eq(List a) given (a|Eq) else all for i:(Fin nx). xs[i] == ys[unsafe_from_ordinal (ordinal i)] -def unsafe_cast_table(xs:from=>a) -> to=>a given (to|Ix, from|Ix, a) = +def unsafe_cast_table(xs:from=>a) -> to=>a given (to|Ix, from|Ix, a:Type) = for i. xs[unsafe_from_ordinal (ordinal i)] -def to_list(xs:n=>a) -> List a given (n|Ix, a) = +def to_list(xs:n=>a) -> List a given (n|Ix, a:Type) = n' = size n - AsList(_, unsafe_cast_table(to=Fin n', xs)) + AsList(n', unsafe_cast_table(to=Fin n', xs)) instance Monoid(List a) given (a|Data) - mempty = AsList(_, []) + mempty = to_list([] :: Fin 0 => a) def (<>)(x, y) = AsList(nx,xs) = x AsList(ny,ys) = y @@ -1437,14 +1433,14 @@ named-instance ListMonoid (a|Data) -> Monoid(List a) mempty = mempty def (<>)(x, y) = x <> y --- TODO Eliminate or reimplement this operation, since it costs O(n) --- where n is the length of the list held in the reference. +# TODO Eliminate or reimplement this operation, since it costs O(n) +# where n is the length of the list held in the reference. def append(list: Ref(h, List a), x:a) -> {Accum h} () - given (a|Data, h) (AccumMonoid(h, List a)) = + given (a|Data, h:Heap) (AccumMonoid(h, List a)) = list += to_list [x] --- TODO: replace `slice` with this? -def post_slice(xs:n=>a, start:Post n, end:Post n) -> List a given (n|Ix, a) = +# TODO: replace `slice` with this? +def post_slice(xs:n=>a, start:Post n, end:Post n) -> List a given (n|Ix, a:Type) = slice_size = unsafe_nat_diff(ordinal end, ordinal start) to_list for i:(Fin slice_size). xs[unsafe_from_ordinal(n=n, ordinal i + ordinal start)] @@ -1456,17 +1452,17 @@ String : Type = List Char def string_from_char_ptr(n:Word32, ptr:Ptr Char) -> {IO} String = AsList(rep_to_nat n, table_from_ptr ptr) --- TODO. This is ASCII code point. It really should be Int32 for Unicode codepoint +# TODO. This is ASCII code point. It really should be Int32 for Unicode codepoint def codepoint(c:Char) -> Int = w8_to_i c struct CString = ptr : RawPtr --- TODO: check the string contains no nulls +# TODO: check the string contains no nulls def with_c_string( s:String, action: (CString) -> {IO} a - ) -> {IO} a given (a) = + ) -> {IO} a given (a:Type) = AsList(n, s') = s <> "\NUL" with_table_ptr s' \ptr. action CString(ptr.val) @@ -1477,7 +1473,7 @@ No particular promises are made to exactly what that representation will contain In particular it is **not** promised to be parseable. Nor does it promise a particular level of precision for numeric values. -interface Show(a) +interface Show(a:Type) show : (a) -> String instance Show(String) @@ -1535,7 +1531,7 @@ instance Show((a, b, c, d)) given (a|Show, b|Show, c|Show, d|Show) '### Parse interface For types that can be parsed from a `String`. -interface Parse(a) +interface Parse(a:Type) parseString : (String) -> Maybe a foreign "strtof" strtofFFI : (RawPtr, RawPtr) -> {IO} Float @@ -1544,7 +1540,7 @@ instance Parse(Float) def parseString(str) = unsafe_io \. AsList(str_len, _) = str with_c_string str \cStr. - with_alloc 1 \end_ptr:(Ptr (Ptr Char)). + with_alloc (Ptr Char) 1 \end_ptr. result = strtofFFI(cStr.ptr, end_ptr.val) str_end_ptr = load end_ptr consumed = raw_ptr_to_i64 str_end_ptr.val - raw_ptr_to_i64 cStr.ptr @@ -1567,15 +1563,15 @@ def copysign(a:Float, b:Float) -> Float = True -> (-a) False -> 0.0 --- Todo: use IEEE floating-point builtins. +# Todo: use IEEE floating-point builtins. infinity = 1.0 / 0.0 nan = 0.0 / 0.0 --- Todo: use IEEE floating-point builtins. +# Todo: use IEEE floating-point builtins. def isinf(x:Float) -> Bool = (x == infinity) || (x == -infinity) def isnan(x:Float) -> Bool = not (x >= x && x <= x) --- Todo: use IEEE-754R 5.11: Floating Point Comparison Relation cmpUnordered. +# Todo: use IEEE-754R 5.11: Floating Point Comparison Relation cmpUnordered. def either_is_nan(x:Float, y:Float) -> Bool = (isnan x) || (isnan y) '## File system operations @@ -1585,7 +1581,7 @@ FilePath : Type = String def is_null_raw_ptr(ptr:RawPtr) -> Bool = raw_ptr_to_i64 ptr == 0 -def from_nullable_raw_ptr(ptr:RawPtr) -> Maybe (Ptr a) given (a) = +def from_nullable_raw_ptr(ptr:RawPtr) -> Maybe (Ptr a) given (a:Type) = if is_null_raw_ptr ptr then Nothing else Just $ Ptr ptr @@ -1615,7 +1611,7 @@ def fopen(path:String, mode:StreamMode) -> {IO} (Stream mode) = with_c_string modeStr \cMode. Stream $ fopenFFI(cPath.ptr, cMode.ptr) -def fclose(stream:Stream mode) -> {IO} () given (mode) = +def fclose(stream:Stream mode) -> {IO} () given (mode:StreamMode) = fcloseFFI stream.ptr () @@ -1629,7 +1625,7 @@ def fwrite(stream:Stream WriteMode, s:String) -> {IO} () = '### Iteration TODO: move this out of the file-system section -def while(body: () -> {|eff} Bool) -> {|eff} () given (eff) = +def while(body: () -> {|eff} Bool) -> {|eff} () given (eff:Effects) = body' : () -> {|eff} Word8 = \. b_to_w8 $ body() %while(body') @@ -1637,19 +1633,14 @@ data IterResult(a|Data) = Continue Done(a) --- TODO: can we improve effect inference so we don't need this? -def lift_state(ref: Ref(h, c), f:(a) -> {|eff} b, x:a) -> {State h|eff} b - given (a, b, c, h, eff) = - f x - --- A little iteration combinator -def iter(body: (Nat) -> {|eff} IterResult a) -> {|eff} a given (a|Data, eff) = - result = yield_state Nothing \resultRef. - i <- with_state 0 +# A little iteration combinator +def iter(body: (Nat) -> {|eff} IterResult a) -> {|eff} a given (a|Data, eff:Effects) = + result = yield_state (Nothing::Maybe a) \resultRef. + i <- with_state (0::Nat) while \. continue = is_nothing $ get resultRef if continue then - case lift_state(resultRef, (\x. lift_state(i, body, x)), get i) of + case body(get(i)) of Continue -> i := get i + 1 Done(result) -> resultRef := Just result continue @@ -1661,7 +1652,7 @@ def bounded_iter( maxIters:Nat, fallback:a, body:(Nat) -> {|eff} IterResult a - ) -> {|eff} a given (a|Data, eff) = iter \i. + ) -> {|eff} a given (a|Data, eff:Effects) = iter \i. if i >= maxIters then Done fallback else body i @@ -1691,7 +1682,6 @@ def error(s:String) -> a given (a|Data) = unsafe_io \. def todo() ->> a given (a|Data) = error "TODO: implement it!" - '### Table operations @noinline @@ -1703,14 +1693,14 @@ def from_ordinal(i:Nat) -> n given (n|Ix) = True -> unsafe_from_ordinal i False -> error $ from_ordinal_error(i, size n) --- TODO: should this be called `from_ordinal`? +# TODO: should this be called `from_ordinal`? def to_ix(i:Nat) -> Maybe n given (n|Ix) = case i < size n of True -> Just $ unsafe_from_ordinal i False -> Nothing --- TODO: could make an `unsafeCastIndex` and this could avoid the runtime copy --- TODO: safe (runtime-checked) and unsafe versions +# TODO: could make an `unsafeCastIndex` and this could avoid the runtime copy +# TODO: safe (runtime-checked) and unsafe versions def cast_table(xs:to=>a) -> from=>a given (from|Ix, to|Ix, a|Data) = case size from == size to of True -> unsafe_cast_table xs @@ -1720,12 +1710,12 @@ def cast_table(xs:to=>a) -> from=>a given (from|Ix, to|Ix, a|Data) = def asidx(i:Nat) -> n given (n|Ix) = from_ordinal i def (@)(i:Nat, n|Ix) -> n = from_ordinal i -def slice(xs:n=>a, start:Nat, m|Ix) -> m=>a given (n|Ix, a) = +def slice(xs:n=>a, start:Nat, m|Ix) -> m=>a given (n|Ix, a:Type) = for i. xs[from_ordinal (ordinal i + start)] -def head(xs:n=>a) -> a given (n|Ix, a) = xs[0@_] +def head(xs:n=>a) -> a given (n|Ix, a:Type) = xs[0@_] -def tail(xs:n=>a, start:Nat) -> List a given (n|Ix, a) = +def tail(xs:n=>a, start:Nat) -> List a given (n|Ix, a:Type) = numElts = size n -| start to_list $ slice(xs, start, Fin numElts) @@ -1736,19 +1726,19 @@ Dex's PRNG system is modelled directly after [JAX's](https://github.com/google/j '### Key functions --- TODO: newtype +# TODO: newtype Key = Word64 @noinline def threefry_2x32(k:Word64, count:Word64) -> Word64 = - -- Based on jax's threefry_2x32 by Matt Johnson and Peter Hawkins - rotations1 = [13, 15, 26, 6] - rotations2 = [17, 29, 16, 24] + # Based on jax's threefry_2x32 by Matt Johnson and Peter Hawkins + rotations1 : Fin 4 => Int32 = [13, 15, 26, 6] + rotations2 : Fin 4 => Int32 = [17, 29, 16, 24] k0 = low_word k k1 = high_word k - -- TODO: add a fromHex - k2 = k0 .^. k1 .^. (n_to_w32 466688986) -- 0x1BD11BDA + # TODO: add a fromHex + k2 = k0 .^. k1 .^. (n_to_w32 466688986) # 0x1BD11BDA x = low_word count y = high_word count @@ -1758,9 +1748,9 @@ def threefry_2x32(k:Word64, count:Word64) -> Word64 = rotations = [rotations1, rotations2] ks = [k1, k2, k0] (x, y) = yield_state (x, y) \ref. for i:(Fin 5). - for j. + for j:(Fin 4). (x, y) = get ref - rotationIndex = unsafe_from_ordinal (ordinal i `mod` 2) + rotationIndex : Fin 2 = unsafe_from_ordinal (ordinal i `mod` 2) rot = rotations[rotationIndex, j] x = x + y y = (y .<<. rot) .|. (y .>>. (32 - rot)) @@ -1777,7 +1767,7 @@ def hash(x:Key, y:Nat) -> Key = y64 = n_to_w64 y threefry_2x32(x, y64) def new_key(x:Nat) -> Key = hash(0, x) -def many(f:(Key)->a, k:Key, i:n) -> a given (a, n|Ix) = f hash(k, ordinal i) +def many(f:(Key)->a, k:Key, i:n) -> a given (a:Type, n|Ix) = f hash(k, ordinal i) def ixkey(k:Key, i:n) -> Key given (n|Ix) = hash(k, ordinal i) def split_key(k:Key) -> Fin n => Key given (n:Nat) = for i. ixkey(k, i) @@ -1786,26 +1776,26 @@ These functions generate samples taken from, different distributions. Such as `rand_mat` with samples from the distribution of floating point matrices where each element is taken from a i.i.d. uniform distribution. Note that additional standard distributions are provided by the `stats` library. def rand(k:Key) -> Float = - exponent_bits = 1065353216 -- 1065353216 = 127 << 23 - mantissa_bits = (high_word k .&. 8388607) -- 8388607 == (1 << 23) - 1 + exponent_bits : Word32 = 1065353216 # 1065353216 = 127 << 23 + mantissa_bits = (high_word k .&. 8388607) # 8388607 == (1 << 23) - 1 bits = exponent_bits .|. mantissa_bits %bitcast(Float, bits) - 1.0 -def rand_vec(n:Nat, f: (Key) -> a, k: Key) -> Fin n => a given (a) = +def rand_vec(n:Nat, f: (Key) -> a, k: Key) -> Fin n => a given (a:Type) = for i:(Fin n). f ixkey(k, i) -def rand_mat(n:Nat, m:Nat, f: (Key) -> a, k: Key) -> Fin n => Fin m => a given (a) = +def rand_mat(n:Nat, m:Nat, f: (Key) -> a, k: Key) -> Fin n => Fin m => a given (a:Type) = for i j. f ixkey(k, (i, j)) def randn(k:Key) -> Float = - [k1, k2] = split_key k - -- rand is uniform between 0 and 1, but implemented such that it rounds to 0 - -- (in float32) once every few million draws, but never rounds to 1. + [k1, k2] = split_key(n=2, k) + # rand is uniform between 0 and 1, but implemented such that it rounds to 0 + # (in float32) once every few million draws, but never rounds to 1. u1 = 1.0 - (rand k1) u2 = rand k2 sqrt ((-2.0) * log u1) * cos (2.0 * pi * u2) --- TODO: Make this better... +# TODO: Make this better... def rand_int(k:Key) -> Nat = w64_to_n k `mod` 2147483647 def randn_vec(k:Key) -> n=>Float given (n|Ix) = @@ -1823,12 +1813,12 @@ instance InnerProd(Float) def inner_prod(x, y) = x * y instance InnerProd(n=>a) given (a|InnerProd, n|Ix) - def inner_prod(x, y) =sum for i. inner_prod(x[i], y[i]) + def inner_prod(x, y) =sum for i:n. inner_prod(x[i], y[i]) '## Arbitrary Type class for generating example values -interface Arbitrary(a) +interface Arbitrary(a:Type) arb : (Key) -> a instance Arbitrary(Bool) @@ -1846,24 +1836,24 @@ instance Arbitrary(Nat) instance Arbitrary(n=>a) given (n|Ix, a|Arbitrary) def arb(key) = for i. arb $ ixkey(key, i) -instance Arbitrary((i:n)=>(.. a) given (n|Ix, a|Arbitrary) +instance Arbitrary((i:n)=> RangeToExc i => a) given (n|Ix, a|Arbitrary) def arb(key) = for i. arb $ ixkey(key, i) -instance Arbitrary((i:n)=>(..i) => a) given (n|Ix, a|Arbitrary) +instance Arbitrary((i:n)=> RangeTo i => a) given (n|Ix, a|Arbitrary) def arb(key) = for i. arb $ ixkey(key, i) -instance Arbitrary((i:n)=>(i..) => a) given (n|Ix, a|Arbitrary) +instance Arbitrary((i:n)=> RangeFrom i => a) given (n|Ix, a|Arbitrary) def arb(key) = for i. arb $ ixkey(key, i) -instance Arbitrary((i:n)=>(i<..) => a) given (n|Ix, a|Arbitrary) +instance Arbitrary((i:n)=> RangeFromExc i => a) given (n|Ix, a|Arbitrary) def arb(key) = for i. arb $ ixkey(key, i) instance Arbitrary((a, b)) given (a|Arbitrary, b|Arbitrary) def arb(key) = - [k1, k2] = split_key key + [k1, k2] = split_key(n=2, key) (arb k1, arb k2) -instance Arbitrary(Fin n) given (n) +instance Arbitrary(Fin n) given (n:Nat) def arb(key) = rand_idx key '## Ord on Arrays @@ -1888,7 +1878,7 @@ def search_sorted(xs:n=>a, x:a) -> Post n given (n|Ix, a|Ord) = else if x < xs[from_ordinal 0] then first_ix else - low <- with_state(0) + low <- with_state(0::Nat) high <- with_state(size n) _ <- iter numLeft = n_to_i (get high) - n_to_i (get low) @@ -1911,53 +1901,54 @@ def search_sorted_exact(xs:n=>a, x:a) -> Maybe n given (n|Ix, a|Ord) = '### min / max etc -def min_by(f:(a)->o, x:a, y:a) -> a given (o|Ord, a) = select(f x < f y, x, y) -def max_by(f:(a)->o, x:a, y:a) -> a given (o|Ord, a) = select(f x > f y, x, y) +def min_by(f:(a)->o, x:a, y:a) -> a given (o|Ord, a:Type) = select(f x < f y, x, y) +def max_by(f:(a)->o, x:a, y:a) -> a given (o|Ord, a:Type) = select(f x > f y, x, y) def min(x1: o, x2: o) -> o given (o|Ord) = min_by(id, x1, x2) def max(x1: o, x2: o) -> o given (o|Ord) = max_by(id, x1, x2) -def minimum_by(f:(a)->o, xs:n=>a) -> a given (a|Data, o|Ord, n|Ix) = - reduce(xs[0@_], \x y. min_by(f, x, y), xs) -def maximum_by(f:(a)->o, xs:n=>a) -> a given (a|Data, o|Ord, n|Ix) = - reduce(xs[0@_], \x y. max_by(f, x, y), xs) +def minimum_by(xs:n=>a, f:(a)->o) -> a given (a|Data, o|Ord, n|Ix) = + reduce(xs, xs[0@_], \x y. min_by(f, x, y)) +def maximum_by(xs:n=>a, f:(a)->o) -> a given (a|Data, o|Ord, n|Ix) = + reduce(xs, xs[0@_], \x y. max_by(f, x, y)) -def minimum(xs:n=>o) -> o given (n|Ix, o|Ord) = minimum_by(id, xs) -def maximum(xs:n=>o) -> o given (n|Ix, o|Ord) = maximum_by(id, xs) +def minimum(xs:n=>o) -> o given (n|Ix, o|Ord) = minimum_by(xs, id) +def maximum(xs:n=>o) -> o given (n|Ix, o|Ord) = maximum_by(xs, id) '### argmin/argmax --- TODO: put in same section as `searchsorted` +# TODO: put in same section as `searchsorted` -def argscan(comp:(o,o)->Bool, xs:n=>o) -> n given (o|Ord, n|Ix) = - zeroth = (0@_, xs[0@_]) - compare = \p1 p2. +def argscan(xs:n=>a, comp:(a,a)->Bool) -> n given (a|Ord, n|Ix) = + AccumTy : Type = (n, a) + zeroth : AccumTy = (0@_, xs[0@_]) + compare = \p1:AccumTy p2:AccumTy. (idx1, x1) = p1 (idx2, x2) = p2 select(comp(x1, x2), (idx1, x1), (idx2, x2)) - zipped = for i. (i, xs[i]) - fst $ reduce(zeroth, compare, zipped) + zipped = for i:n. (i, xs[i]) + fst $ reduce(zipped, zeroth, compare) -def argmin(xs:n=>o) -> n given (n|Ix, o|Ord) = argscan((<), xs) -def argmax(xs:n=>o) -> n given (n|Ix, o|Ord) = argscan((>), xs) +def argmin(xs:n=>a) -> n given (n|Ix, a|Ord) = argscan(xs, (<)) +def argmax(xs:n=>a) -> n given (n|Ix, a|Ord) = argscan(xs, (>)) def lexical_order( - compareElements:(n,n)->Bool, - compareLengths: (Nat,Nat)->Bool, xList:List n, - yList:List n + yList:List n, + compareElements:(n,n)->Bool, + compareLengths: (Nat,Nat)->Bool ) -> Bool given (n|Ord) = - -- Orders Lists according to the order of their elements, - -- in the same way a dictionary does. - -- For example, this lets us sort Strings. - -- - -- More precisely, it returns True iff compareElements xs.i ys.i is true - -- at the first location they differ. - -- - -- This function operates serially and short-circuits - -- at the first difference. One could also write this - -- function as a parallel reduction, but it would be - -- wasteful in the case where there is an early difference, - -- because we can't short circuit. + # Orders Lists according to the order of their elements, + # in the same way a dictionary does. + # For example, this lets us sort Strings. + # + # More precisely, it returns True iff compareElements xs.i ys.i is true + # at the first location they differ. + # + # This function operates serially and short-circuits + # at the first difference. One could also write this + # function as a parallel reduction, but it would be + # wasteful in the case where there is an early difference, + # because we can't short circuit. AsList(nx, xs) = xList AsList(ny, ys) = yList iter \i. @@ -1973,8 +1964,8 @@ def lexical_order( False -> Done False instance Ord(List n) given (n|Ord) - def (>)(xs, ys) = lexical_order((>), (>), xs, ys) - def (<)(xs, ys) = lexical_order((<), (<), xs, ys) + def (>)(xs, ys) = lexical_order(xs, ys, (>), (>)) + def (<)(xs, ys) = lexical_order(xs, ys, (<), (<)) '### clip @@ -1987,10 +1978,10 @@ TODO: these should be with the other Elementary/Special Functions ### atan/atan2 def atan_inner(x:Float) -> Float = - -- From "Computing accurate Horner form approximations to - -- special functions in finite precision arithmetic" - -- https://arxiv.org/abs/1508.03211 - -- Only accurate in the range [-1, 1] + # From "Computing accurate Horner form approximations to + # special functions in finite precision arithmetic" + # https://arxiv.org/abs/1508.03211 + # Only accurate in the range [-1, 1] s = x * x r = 0.0027856871 r = r * s - 0.0158660002 @@ -2005,13 +1996,13 @@ def atan_inner(x:Float) -> Float = def min_and_max(x:a, y:a) -> (a, a) given (a|Ord) = - select(x < y, (x, y), (y, x)) -- get both with one comparison. + select(x < y, (x, y), (y, x)) # get both with one comparison. def atan2(y:Float, x:Float) -> Float = - -- Based off of the Tensorflow implementation at - -- github.com/tensorflow/mlir-hlo/blob/master/lib/ - -- Dialect/mhlo/transforms/legalize_trigonometric_to_approximation.cc#L147 - -- With a fix to the nan propagation. + # Based off of the Tensorflow implementation at + # github.com/tensorflow/mlir-hlo/blob/master/lib/ + # Dialect/mhlo/transforms/legalize_trigonometric_to_approximation.cc#L147 + # With a fix to the nan propagation. abs_x = abs x abs_y = abs y (min_abs_x_y, max_abs_x_y) = min_and_max(abs_x, abs_y) @@ -2021,9 +2012,9 @@ def atan2(y:Float, x:Float) -> Float = t = select(x < 0.0, pi, 0.0) a = select(y == 0.0, t, a) t = select(x < 0.0, 3.0 * pi / 4.0, pi / 4.0) - a = select(isinf x && isinf y, t, a) -- Handle infinite inputs. + a = select(isinf x && isinf y, t, a) # Handle infinite inputs. a = copysign(a, y) - select(either_is_nan(x, y), nan, a) -- Propagate NaNs. + select(either_is_nan(x, y), nan, a) # Propagate NaNs. def atan(x:Float) -> Float = atan2(x, 1.0) @@ -2033,13 +2024,13 @@ TODO: all of these should be in some other section def reflect(i:n) -> n given (n|Ix) = unsafe_from_ordinal $ unsafe_nat_diff(size n, ordinal i + 1) -def reverse(x:n=>a) -> n=>a given (n|Ix, a) = - for i. x[reflect i] +def reverse(x:n=>a) -> n=>a given (n|Ix, a:Type) = + for i:n. x[reflect i] def wrap_periodic(n|Ix, i:Nat) -> n = unsafe_from_ordinal(n=n, i `mod` size n) -def pad_to(m|Ix, x:a, xs:n=>a) -> m=>a given (n|Ix, a) = +def pad_to(m|Ix, x:a, xs:n=>a) -> m=>a given (n|Ix, a:Type) = n' = size n for i. i' = ordinal i @@ -2054,25 +2045,25 @@ def is_odd(x:Nat) -> Bool = rem(x, 2) == 1 def is_even(x:Nat) -> Bool = rem(x, 2) == 0 def is_power_of_2(x:Nat) -> Bool = - -- A fast trick based on bitwise AND. - -- This works on integer types larger than 8 bits. - -- Note: The bitwise and operator (.&.) - -- is only defined for Byte, which is why - -- we use %and here. TODO: Make (.&.) polymorphic. + # A fast trick based on bitwise AND. + # This works on integer types larger than 8 bits. + # Note: The bitwise and operator (.&.) + # is only defined for Byte, which is why + # we use %and here. TODO: Make (.&.) polymorphic. x' = nat_to_rep x if x' == 0 then False - else 0 == %and(x', (%isub(x', 1::NatRep))) - --- This computes the integer part of the binary logarithm of the input. --- TODO: natlog2 0 should do something other than underflow the answer. --- TODO: Use LLVM ctlz intrinsic instead. It needs a slightly new --- code path in ImpToLLVM, because it's the first LLVM intrinsic --- we have with a fixed-point argument. --- https://llvm.org/docs/LangRef.html#llvm-ctlz-intrinsic + else %and(x', (%isub(x', 1::NatRep))) == 0 + +# This computes the integer part of the binary logarithm of the input. +# TODO: natlog2 0 should do something other than underflow the answer. +# TODO: Use LLVM ctlz intrinsic instead. It needs a slightly new +# code path in ImpToLLVM, because it's the first LLVM intrinsic +# we have with a fixed-point argument. +# https://llvm.org/docs/LangRef.html#llvm-ctlz-intrinsic def natlog2(x:Nat) -> Nat = - tmp = yield_state 0 \ans. - cmp <- run_state 1 + tmp = yield_state (0::Nat) \ans. + cmp <- run_state (1::Nat) while \. if x >= (get cmp) then @@ -2081,7 +2072,7 @@ def natlog2(x:Nat) -> Nat = True else False - unsafe_nat_diff(tmp, 1) -- TODO: something less horrible + unsafe_nat_diff(tmp, 1) # TODO: something less horrible def nextpow2(x:Nat) -> Nat = case is_power_of_2 x of @@ -2093,10 +2084,10 @@ def general_integer_power( one:a, base:a, power:Nat ) -> a given (a|Data) = - iters = if power == 0 then 0 else 1 + natlog2 power - -- Implements exponentiation by squaring. - -- This could be nicer if there were a way to explicitly - -- specify which typelcass instance to use for Mul. + iters : Nat = if power == 0 then 0 else 1 + natlog2 power + # Implements exponentiation by squaring. + # This could be nicer if there were a way to explicitly + # specify which typelcass instance to use for Mul. yield_state one \ans. pow <- with_state power z <- with_state base @@ -2109,33 +2100,33 @@ def general_integer_power( def intpow(base:a, power:Nat) -> a given (a|Mul) = general_integer_power((*), one, base, power) -def from_just(x:Maybe a) -> a given (a) = case x of Just(x') -> x' +def from_just(x:Maybe a) -> a given (a:Type) = case x of Just(x') -> x' -def any_sat(f:(a)->Bool, xs:n=>a) -> Bool given (a, n|Ix) = any(each xs f) +def any_sat(xs:n=>a, f:(a)->Bool) -> Bool given (a:Type, n|Ix) = any(each xs f) -def seq_maybes(xs: n=>Maybe a) -> Maybe (n => a) given (n|Ix, a) = - -- is it possible to implement this safely? (i.e. without using partial - -- functions) - case any_sat(is_nothing, xs) of +def seq_maybes(xs: n=>Maybe a) -> Maybe (n => a) given (n|Ix, a:Type) = + # is it possible to implement this safely? (i.e. without using partial + # functions) + case any_sat(xs, is_nothing) of True -> Nothing False -> Just $ each xs from_just def linear_search(xs:n=>a, query:a) -> Maybe n given (n|Ix, a|Eq) = - yield_state Nothing \ref. for i. + yield_state Nothing \ref. for i:n. case xs[i] == query of True -> ref := Just i False -> () -def list_length(l:List a) -> Nat given (a) = +def list_length(l:List a) -> Nat given (a:Type) = AsList(n, _) = l n --- This is for efficiency (rather than using `<>` repeatedly) --- TODO: we want this for any monoid but this implementation won't work. -def concat(lists:n=>(List a)) -> List a given (a, n|Ix) = - totalSize = sum for i. list_length lists[i] - to_list $ with_state 0 \listIdx. - eltIdx <- with_state 0 +# This is for efficiency (rather than using `<>` repeatedly) +# TODO: we want this for any monoid but this implementation won't work. +def concat(lists:n=>(List a)) -> List a given (a:Type, n|Ix) = + totalSize = sum for i:n. list_length lists[i] + to_list $ with_state (0::Nat) \listIdx. + eltIdx <- with_state (0::Nat) for i:(Fin totalSize). while \. continue = get eltIdx >= list_length (lists[(get listIdx)@_]) @@ -2151,8 +2142,10 @@ def concat(lists:n=>(List a)) -> List a given (a, n|Ix) = xs[eltIdxVal@_] def cat_maybes(xs:n=>Maybe a) -> List a given (n|Ix, a|Data) = - (num_res, res_inds) = yield_state (0::Nat, for i:n. Nothing) \ref. - for i. case xs[i] of + StateTy : Type = (Nat, n=>Maybe n) + init_state : StateTy = (0, for i. Nothing) + (num_res, res_inds) = yield_state init_state \ref. + for i:n. case xs[i] of Just(_) -> ix = get ref.0 ref.1 ! (unsafe_from_ordinal ix) := Just i @@ -2162,37 +2155,37 @@ def cat_maybes(xs:n=>Maybe a) -> List a given (n|Ix, a|Data) = case res_inds[unsafe_from_ordinal $ ordinal i] of Just(j) -> case xs[j] of Just(x) -> x - Nothing -> todo -- Impossible - Nothing -> todo -- Impossible + Nothing -> todo # Impossible + Nothing -> todo # Impossible def filter(xs:n=>a, condition:(a)->Bool) -> List a given (a|Data, n|Ix) = - cat_maybes $ for i. if condition xs[i] then Just xs[i] else Nothing + cat_maybes $ for i:n. if condition xs[i] then Just xs[i] else Nothing def arg_filter(xs:n=>a, condition:(a)->Bool) -> List n given (a|Data, n|Ix) = - cat_maybes $ for i. if condition xs[i] then Just i else Nothing + cat_maybes $ for i:n. if condition xs[i] then Just i else Nothing --- TODO: use `ix_offset : [Ix n] -> n -> Int -> Maybe n` instead +# TODO: use `ix_offset : [Ix n] -> n -> Int -> Maybe n` instead def prev_ix(i:n) -> Maybe n given (n|Ix) = case i_to_n (n_to_i (ordinal i) - 1) of Nothing -> Nothing Just(i_prev) -> unsafe_from_ordinal(i_prev) | Just def lines(source:String) -> List String = - AsList(_, s) = source - AsList(num_lines, newline_ixs) = cat_maybes for i_char. + AsList(num_chars, s) = source + AsList(num_lines, newline_ixs) = cat_maybes for i_char:(Fin num_chars). if s[i_char] == '\n' then Just(i_char) else Nothing to_list for i_line:(Fin num_lines). start = case prev_ix i_line of - Nothing -> first_ix Just(i) -> right_post newline_ixs[i] + Nothing -> first_ix end = left_post newline_ixs[i_line] post_slice(s, start, end) '## Probability --- cdf should include 0.0 but not 1.0 +# cdf should include 0.0 but not 1.0 def categorical_from_cdf(cdf: n=>Float, key: Key) -> n given (n|Ix) = r = rand key from_just $ left_fence $ search_sorted(cdf, r) @@ -2201,20 +2194,20 @@ def normalize_pdf(xs: d=>Float) -> d=>Float given (d|Ix) = xs / sum xs def cdf_for_categorical(logprobs: n=>Float) -> n=>Float given (n|Ix) = maxLogProb = maximum logprobs - cumsum_low $ normalize_pdf $ for i. exp(logprobs[i] - maxLogProb) + cumsum_low $ normalize_pdf $ for i:n. exp(logprobs[i] - maxLogProb) def categorical(logprobs: n=>Float, key: Key) -> n given (n|Ix) = categorical_from_cdf(cdf_for_categorical logprobs, key) --- batch variant to share the work of forming the cumsum --- (alternatively we could rely on hoisting of loop constants) +# batch variant to share the work of forming the cumsum +# (alternatively we could rely on hoisting of loop constants) def categorical_batch(logprobs: n=>Float, key: Key) -> m=>n given (n|Ix, m|Ix) = cdf = cdf_for_categorical logprobs - for i. categorical_from_cdf(cdf, ixkey(key, i)) + for i:m. categorical_from_cdf(cdf, ixkey(key, i)) def logsumexp(x: n=>Float) -> Float given (n|Ix) = m = maximum x - m + (log $ sum for i. exp (x[i] - m)) + m + (log $ sum for i:n. exp (x[i] - m)) def logsoftmax(x: n=>Float) -> n=>Float given (n|Ix) = lse = logsumexp x @@ -2222,25 +2215,25 @@ def logsoftmax(x: n=>Float) -> n=>Float given (n|Ix) = def softmax(x: n=>Float) -> n=>Float given (n|Ix) = m = maximum x - e = for i. exp (x[i] - m) + e = for i:n. exp (x[i] - m) s = sum e for i. e[i] / s '## Polynomials TODO: Move this somewhere else -def evalpoly(coefficients:n=>v, x:Float) -> v given (n|Ix, v|VSpace) = - -- Evaluate a polynomial at x. Same as Numpy's polyval. - fold zero \i c. coefficients[i] + x .* c +def evalpoly(coeffs:n=>v, x:Float) -> v given (n|Ix, v|VSpace) = + # Evaluate a polynomial at x. Same as Numpy's polyval. + fold zero coeffs \i coeff c. coeff + x .* c '## Exception effect --- TODO: move `error` and `todo` to here. +# TODO: move `error` and `todo` to here. -def catch(f:() -> {Except|eff} a) -> {|eff} Maybe a given (a, eff)= +def catch(f:() -> {Except|eff} a) -> {|eff} Maybe a given (a:Type, eff:Effects)= f' : (() -> {Except|eff} a) = \. f() %catchException(f') -def throw() -> {Except} a given (a) = +def throw() -> {Except} a given (a:Type) = %throwException(a) def assert(b:Bool) -> {Except} () = @@ -2270,22 +2263,22 @@ instance Subset(b, Either(a,b)) given (a|Data, b|Data) def int_to_reversed_digits(k:Nat) -> a=>b given (a|Ix, b|Ix) = base = size b - snd $ scan k \_ cur_k. + fst $ scan k (for i:a. ()) \_ _ cur_k. next_k = cur_k `idiv` base digit = cur_k `mod` base - (next_k, unsafe_from_ordinal(n=b, digit)) + (unsafe_from_ordinal(n=b, digit), next_k) def reversed_digits_to_int(digits: a=>b) -> Nat given (a|Ix, b|Ix) = base = size b - fst $ fold (0, 1) \j pair. + fst $ fold (0::Nat, 1::Nat) digits \j digit pair. (cur_k, cur_base) = pair - next_k = cur_k + ordinal digits[j] * cur_base + next_k = cur_k + ordinal digit * cur_base next_base = cur_base * base (next_k, next_base) instance Ix(a=>b) given (a|Ix, b|Ix) - -- 0@a is the least significant digit, - -- while (size a - 1)@a is the most significant digit. + # 0@a is the least significant digit, + # while (size a - 1)@a is the most significant digit. def size'() = size b `intpow` size a def ordinal(i) = reversed_digits_to_int i def unsafe_from_ordinal(i) = int_to_reversed_digits i @@ -2352,18 +2345,19 @@ struct Stack(h:Heap, a|Data) = self.size_ref := n_new Just $ get buf!(unsafe_from_ordinal n_new) -stack_init_size = 16 +stack_init_size : Nat = 16 + def with_stack( a|Data, action:(given (h:Heap), Stack(h, a)) -> {State h|eff} r - ) -> {|eff} r given (eff, r) = - init_stack = to_list for i:(Fin stack_init_size). uninitialized_value() - with_state (0, init_stack) \ref . action(Stack(ref.0, ref.1)) + ) -> {|eff} r given (eff:Effects, r:Type) = + init_stack = to_list for i:(Fin stack_init_size). uninitialized_value() :: a + with_state (0::Nat, init_stack) \ref . action(Stack(ref.0, ref.1)) -def stack_extend_internal(stack:Stack(h, Char), x:Fin n=>Char) -> {State h} () given (n, h) = +def stack_extend_internal(stack:Stack(h, Char), x:Fin n=>Char) -> {State h} () given (n:Nat, h:Heap) = stack.extend(x) -def stack_push_internal(stack:Stack(h, Char), x:Char) -> {State h} () given (h) = +def stack_push_internal(stack:Stack(h, Char), x:Char) -> {State h} () given (h:Heap) = stack.push(x) def with_stack_internal(f:(given (h:Heap), Stack(h, Char)) -> {State h} ()) -> List Char = @@ -2387,7 +2381,7 @@ def from_c_string(s:CString) -> {IO} (Maybe String) = stack.push(c) Continue -def show_any(x:a) -> String given (a) = unsafe_coerce(to=String, %showAny(x)) +def show_any(x:a) -> String given (a:Type) = unsafe_coerce(to=String, %showAny(x)) def coerce_table(m|Ix, x:n=>a) -> m => a given (n|Ix, a|Data) = if size m == size n @@ -2407,12 +2401,12 @@ def check_env(name:String) -> {IO} Bool = '## Testing Helpers --- -- Reliably causes a segfault if pointers aren't initialized to zero. --- -- TODO: add this test when we cache modules --- justSomeDataToTestCaching = toList for i:(Fin 100). --- if ordinal i == 0 --- then Left (toList [1,2,3]) --- else Right 1 +# # Reliably causes a segfault if pointers aren't initialized to zero. +# # TODO: add this test when we cache modules +# justSomeDataToTestCaching = toList for i:(Fin 100). +# if ordinal i == 0 +# then Left (toList [1,2,3]) +# else Right 1 '### TestMode @@ -2422,20 +2416,19 @@ def dex_test_mode() -> Bool = unsafe_io \. check_env "DEX_TEST_MODE" '### More Stream IO def fread(stream:Stream ReadMode) -> {IO} String = - -- TODO: allow reading longer files! - n = 4096 - ptr:(Ptr Char) <- with_alloc n + # TODO: allow reading longer files! + n : Nat = 4096 + ptr <- with_alloc(Char, n) stack <- with_stack Char iter \_. numRead = i_to_w32 $ i64_to_i $ freadFFI(ptr.val, 1, n_to_i64 n, stream.ptr) AsList(_, new_chars) = string_from_char_ptr(numRead, ptr) stack.extend(new_chars) - if numRead == n_to_w32 n - then Continue - else Done () + case numRead == n_to_w32 n of + True -> Continue :: IterResult () + False -> Done () stack.read() - '### Shelling Out foreign "popen" popenFFI : (RawPtr, RawPtr) -> {IO} RawPtr @@ -2447,8 +2440,7 @@ def shell_out(command:String) -> {IO} String = modeStr = "r" with_c_string command \command'. with_c_string modeStr \modeStr'. - pipe = Stream $ popenFFI(command'.ptr, modeStr'.ptr) - fread pipe + fread $ Stream $ popenFFI(command'.ptr, modeStr'.ptr) '### File Operations @@ -2491,16 +2483,16 @@ def new_temp_file() -> {IO} FilePath = closeFFI fd string_from_char_ptr(15, (Ptr s.ptr)) -def with_temp_file(action: (FilePath) -> {IO} a) -> {IO} a given (a) = +def with_temp_file(action: (FilePath) -> {IO} a) -> {IO} a given (a:Type) = tmpFile = new_temp_file() result = action tmpFile delete_file tmpFile result -def with_temp_files(action: (n=>FilePath) -> {IO} a) -> {IO} a given (n|Ix, a) = - tmpFiles = for i. new_temp_file() +def with_temp_files(n|Ix, action: (n=>FilePath) -> {IO} a) -> {IO} a given (a:Type) = + tmpFiles = for i:n. new_temp_file() result = action tmpFiles - for i. delete_file tmpFiles[i] + each tmpFiles delete_file result '### Linear Algebra @@ -2509,16 +2501,16 @@ def linspace(n|Ix, low:Float, high:Float) -> n=>Float = dx = (high - low) / n_to_f (size n) for i:n. low + n_to_f (ordinal i) * dx -def transpose(x:n=>m=>a) -> m=>n=>a given (n|Ix, m|Ix, a) = for i j. x[j,i] -def vdot(x:n=>Float, y:n=>Float) -> Float given (n|Ix) = fsum for i. x[i] * y[i] -def dot(s:n=>Float, vs:n=>v) -> v given (n|Ix, v|VSpace) = sum for j. s[j] .* vs[j] +def transpose(x:n=>m=>a) -> m=>n=>a given (n|Ix, m|Ix, a:Type) = for i j. x[j,i] +def vdot(x:n=>Float, y:n=>Float) -> Float given (n|Ix) = fsum for i:n. x[i] * y[i] +def dot(s:n=>Float, vs:n=>v) -> v given (n|Ix, v|VSpace) = sum for j:n. s[j] .* vs[j] def naive_matmul(x: l=>m=>Float, y: m=>n=>Float) -> (l=>n=>Float) given (l|Ix, m|Ix, n|Ix) = - for i k. fsum for j. x[i,j] * y[j,k] + for i k. fsum for j:m. x[i,j] * y[j,k] --- A `FullTileIx` type represents `tile_ix`th full tile (of size --- `tile_size`) iterating over the index set `n`. --- This type is only well formed when tile_ix * tile_size < size n. +# A `FullTileIx` type represents `tile_ix`th full tile (of size +# `tile_size`) iterating over the index set `n`. +# This type is only well formed when tile_ix * tile_size < size n. struct FullTileIx(n|Ix, tile_size:Nat, tile_ix:Nat) = unwrap : Fin tile_size @@ -2532,9 +2524,9 @@ instance Subset(FullTileIx(n, tile_size, tile_ix), n) given (n|Ix, tile_size:Nat def project'(i) = todo def unsafe_project'(i) = todo --- A `CodaIx` type represents the last few elements of the index set `n`, --- as might be left over after iterating by tiles. --- This type is only well formed when size n == coda_offset + coda_size +# A `CodaIx` type represents the last few elements of the index set `n`, +# as might be left over after iterating by tiles. +# This type is only well formed when size n == coda_offset + coda_size struct CodaIx(n|Ix, coda_offset:Nat, coda_size:Nat) = unwrap : Fin coda_size @@ -2552,7 +2544,7 @@ def tile( n|Ix, tile_size: Nat, body:(m:Type, given () (Ix m, Subset(m, n))) -> {|eff} () - ) -> {|eff} () given (eff) = + ) -> {|eff} () given (eff:Effects) = num_tiles = size n `idiv` tile_size coda_size = size n `rem` tile_size coda_offset = num_tiles * tile_size @@ -2566,10 +2558,10 @@ def tiled_matmul( x: l=>m=>Float, y: m=>n=>Float ) -> l=>n=>Float given (l|Ix, m|Ix, n|Ix) = - -- Tile sizes picked for axch's laptop - l_tile_size = 32 - n_tile_size = 128 - m_tile_size = 8 + # Tile sizes picked for axch's laptop + l_tile_size : Nat = 32 + n_tile_size : Nat = 128 + m_tile_size : Nat = 8 yield_accum (AddMonoid Float) \result. tile(l, l_tile_size) \l_set. tile(n, n_tile_size) \n_set. @@ -2577,12 +2569,12 @@ def tiled_matmul( for_ l_offset:l_set. l_ix = inject(to=l, l_offset) for_ m_offset:m_set. - m_ix = inject m_offset + m_ix = inject(to=m, m_offset) for_ n_offset:n_set. - n_ix = inject n_offset + n_ix = inject(to=n, n_offset) result!l_ix!n_ix += x[l_ix][m_ix] * y[m_ix][n_ix] --- matmul. Better symbol to use? `@`? +# matmul. Better symbol to use? `@`? def (**)( x: l=>m=>Float, y: m=>n=>Float @@ -2592,8 +2584,9 @@ def (**)( def matmul_linearization( x: l=>m=>Float, y: m=>n=>Float - ) -> _ given (l|Ix, m|Ix, n|Ix) = - def lin(xt: l=>m=>Float, yt: m=>n=>Float) -> _ = + ) -> (l=>n=>Float, (l=>m=>Float, m=>n=>Float)->l=>n=>Float) + given (l|Ix, m|Ix, n|Ix) = + def lin(xt: l=>m=>Float, yt: m=>n=>Float) -> l=>n=>Float = x ** yt + xt ** y (x ** y, lin) @@ -2605,7 +2598,7 @@ def(.**)(v: n=>Float, mat: n=>m=>Float) -> (m=>Float) given (n|Ix, m|Ix) = transpose mat **. v def inner(x:n=>Float, mat:n=>m=>Float, y:m=>Float) -> Float given (n|Ix, m|Ix) = - fsum for p. + fsum for p:(n,m). (i,j) = p x[i] * mat[i,j] * y[j] diff --git a/misc/dex.el b/misc/dex.el index bc56b2934..de892c45e 100644 --- a/misc/dex.el +++ b/misc/dex.el @@ -5,7 +5,7 @@ ;; https://developers.google.com/open-source/licenses/bsd (setq dex-highlights - `(("--\\([^o].*$\\|$\\)" . font-lock-comment-face) + `(("#.*$" . font-lock-comment-face) ("^> .*$" . font-lock-comment-face) ("^'\\(.\\|\n.\\)*\n\n" . font-lock-comment-face) ("\\w+:" . font-lock-comment-face) @@ -18,7 +18,6 @@ "\\bwith\\b\\|\\bself\\b\\|" "\\bimport\\b\\|\\bforeign\\b\\|\\bsatisfying\\b") . font-lock-keyword-face) - ("--o" . font-lock-variable-name-face) ("[-.,!;$^&*:~+/=<>|?\\\\]" . font-lock-variable-name-face) ("\\b[[:upper:]][[:alnum:]]*\\b" . font-lock-type-face) ("^@[[:alnum:]]*\\b" . font-lock-keyword-face) @@ -36,7 +35,7 @@ (define-derived-mode dex-mode fundamental-mode "dex" (setq font-lock-defaults '(dex-highlights)) - (setq-local comment-start "--") + (setq-local comment-start "#") (setq-local comment-end "") (setq-local syntax-propertize-function (syntax-propertize-rules (".>\\( +\\)" (1 ".")))) diff --git a/src/dex.hs b/src/dex.hs index 20e90a885..f8bce6475 100644 --- a/src/dex.hs +++ b/src/dex.hs @@ -21,25 +21,25 @@ import Data.List import qualified Data.Text as T import qualified Data.Text.Encoding as T import qualified Data.Map.Strict as M +import qualified System.Console.ANSI as ANSI +import System.Console.ANSI hiding (Color) -import PPrint (toJSONStr, printResult) import TopLevel import Err import Name import AbstractSyntax (parseTopDeclRepl) import ConcreteSyntax (keyWordStrs, preludeImportBlock) -#ifdef DEX_LIVE import RenderHtml -import Live.Terminal (runTerminal) +-- import Live.Terminal (runTerminal) import Live.Web (runWeb) -#endif +import PPrint hiding (hardline) import Core import Types.Core import Types.Imp -import Types.Misc import Types.Source +import Types.Top +import MonadUtil -data ErrorHandling = HaltOnErr | ContinueOnErr data DocFmt = ResultOnly | TextDoc | JSONDoc @@ -58,40 +58,26 @@ data EvalMode = ReplMode String data CmdOpts = CmdOpts EvalMode EvalConfig runMode :: EvalMode -> EvalConfig -> IO () -runMode evalMode opts = case evalMode of +runMode evalMode cfg = case evalMode of ScriptMode fname fmt onErr -> do env <- loadCache - (litProg, finalEnv) <- runTopperM opts env do + ((), finalEnv) <- runTopperM cfg env do source <- liftIO $ T.decodeUtf8 <$> BS.readFile fname - evalSourceText source (printIncrementalSource fmt) \result@(Result _ errs) -> do - printIncrementalResult fmt result - return case (onErr, errs) of (HaltOnErr, Failure _) -> False; _ -> True - printFinal fmt litProg + evalSourceText source $ printIncrementalSource fmt storeCache finalEnv ReplMode prompt -> do env <- loadCache - void $ runTopperM opts env do - evalBlockRepl preludeImportBlock + void $ runTopperM cfg env do + evalSourceBlockRepl preludeImportBlock forever do block <- readSourceBlock prompt - evalBlockRepl block - where - evalBlockRepl :: (Topper m, Mut n) => SourceBlock -> m n () - evalBlockRepl block = do - result <- evalSourceBlockRepl block - case result of - Result [] (Success ()) -> return () - _ -> liftIO $ putStrLn $ pprint result + evalSourceBlockRepl block ClearCache -> clearCache #ifdef DEX_LIVE - -- These are broken if the prelude produces any arrays because the blockId - -- counter restarts at zero. TODO: make prelude an implicit import block WebMode fname -> do env <- loadCache - runWeb fname opts env - WatchMode fname -> do - env <- loadCache - runTerminal fname opts env + runWeb fname cfg env + WatchMode _ -> error "not implemented" #endif printIncrementalSource :: DocFmt -> SourceBlock -> IO () @@ -103,26 +89,6 @@ printIncrementalSource fmt sb = case fmt of HTMLDoc -> return () #endif -printIncrementalResult :: DocFmt -> Result -> IO () -printIncrementalResult fmt result = case fmt of - ResultOnly -> case pprint result of [] -> return (); msg -> putStrLn msg - JSONDoc -> case toJSONStr result of "{}" -> return (); s -> putStrLn s - TextDoc -> do - isatty <- queryTerminal stdOutput - putStr $ printResult isatty result -#ifdef DEX_LIVE - HTMLDoc -> return () -#endif - -printFinal :: DocFmt -> [(SourceBlock, Result)] -> IO () -printFinal fmt prog = case fmt of - ResultOnly -> return () - TextDoc -> return () - JSONDoc -> return () -#ifdef DEX_LIVE - HTMLDoc -> putStr $ progHtml prog -#endif - readSourceBlock :: (MonadIO (m n), EnvReader m) => String -> m n SourceBlock readSourceBlock prompt = do sourceMap <- withEnv $ envSourceMap . moduleEnv @@ -210,24 +176,43 @@ parseEvalOpts = EvalConfig <*> (option pathOption $ long "lib-path" <> value [LibBuiltinPath] <> metavar "PATH" <> help "Library path") <*> optional (strOption $ long "prelude" <> metavar "FILE" <> help "Prelude file") - <*> optional (strOption $ long "logto" - <> metavar "FILE" - <> help "File to log to" <> showDefault) - <*> pure Nothing <*> flag NoOptimize Optimize (short 'O' <> help "Optimize generated code") <*> enumOption "print" "Print backend" PrintCodegen printBackends + <*> flag ContinueOnErr HaltOnErr ( + long "stop-on-error" + <> help "Stop program evaluation when an error occurs (type or runtime)") + <*> enumOption "loglevel" "Log level" NormalLogLevel logLevels + <*> pure stdOutLogger where printBackends = [ ("haskell", PrintHaskell) , ("dex" , PrintCodegen) ] - backends = [ ("llvm", LLVM) - , ("llvm-mc", LLVMMC) -#ifdef DEX_CUDA - , ("llvm-cuda", LLVMCUDA) -#endif -#if DEX_LLVM_VERSION == HEAD - , ("mlir", MLIR) -#endif - , ("interpreter", Interpreter)] + backends = [ ("llvm" , LLVM ) + , ("llvm-mc", LLVMMC) ] + logLevels = [ ("normal", NormalLogLevel) + , ("debug" , DebugLogLevel ) ] + +stdOutLogger :: Outputs -> IO () +stdOutLogger (Outputs outs) = do + isatty <- queryTerminal stdOutput + forM_ outs \out -> putStr $ printOutput isatty out + +printOutput :: Bool -> Output -> String +printOutput isatty out = case out of + Error _ -> addColor isatty Red $ addPrefix ">" $ pprint out + _ -> addPrefix (addColor isatty Cyan ">") $ pprint $ out + +addPrefix :: String -> String -> String +addPrefix prefix str = unlines $ map prefixLine $ lines str + where prefixLine :: String -> String + prefixLine s = case s of "" -> prefix + _ -> prefix ++ " " ++ s + +addColor :: Bool -> ANSI.Color -> String -> String +addColor False _ s = s +addColor True c s = + setSGRCode [SetConsoleIntensity BoldIntensity, SetColor Foreground Vivid c] + ++ s ++ setSGRCode [Reset] + pathOption :: ReadM [LibPath] pathOption = splitPaths [] <$> str @@ -242,13 +227,7 @@ pathOption = splitPaths [] <$> str "BUILTIN_LIBRARIES" -> LibBuiltinPath path -> LibDirectory path -openLogFile :: EvalConfig -> IO EvalConfig -openLogFile EvalConfig {..} = do - logFile <- forM logFileName (`openFile` WriteMode) - return $ EvalConfig {..} - main :: IO () main = do CmdOpts evalMode opts <- execParser parseOpts - opts' <- openLogFile opts - runMode evalMode opts' + runMode evalMode opts diff --git a/src/lib/AbstractSyntax.hs b/src/lib/AbstractSyntax.hs index e2143fd9a..e7a396ce6 100644 --- a/src/lib/AbstractSyntax.hs +++ b/src/lib/AbstractSyntax.hs @@ -53,37 +53,35 @@ import Data.Functor import Data.Either import Data.Maybe (catMaybes) import Data.Set qualified as S -import Data.String (fromString) import Data.Text (Text) import ConcreteSyntax import Err import Name -import PPrint () +import PPrint import Types.Primitives -import SourceInfo import Types.Source import qualified Types.OpNames as P import Util -- === Converting concrete syntax to abstract syntax === -parseExpr :: Fallible m => Group -> m (UExpr VoidS) +parseExpr :: Fallible m => GroupW -> m (UExpr VoidS) parseExpr e = liftSyntaxM $ expr e -parseDecl :: Fallible m => CTopDecl -> m (UTopDecl VoidS VoidS) +parseDecl :: Fallible m => CTopDeclW -> m (UTopDecl VoidS VoidS) parseDecl d = liftSyntaxM $ topDecl d parseBlock :: Fallible m => CSBlock -> m (UBlock VoidS) parseBlock b = liftSyntaxM $ block b liftSyntaxM :: Fallible m => SyntaxM a -> m a -liftSyntaxM cont = liftExcept $ runFallibleM cont +liftSyntaxM cont = liftExcept cont parseTopDeclRepl :: Text -> Maybe SourceBlock parseTopDeclRepl s = case sbContents b of UnParseable True _ -> Nothing - _ -> case runFallibleM (checkSourceBlockParses $ sbContents b) of + _ -> case checkSourceBlockParses $ sbContents b of Success _ -> Just b Failure _ -> Nothing where b = mustParseSourceBlock s @@ -91,7 +89,7 @@ parseTopDeclRepl s = case sbContents b of checkSourceBlockParses :: SourceBlock' -> SyntaxM () checkSourceBlockParses = \case - TopDecl (WithSrc _ (CSDecl ann (CExpr e)))-> do + TopDecl (WithSrcs _ _ (CSDecl ann (CExpr e)))-> do when (ann /= PlainLet) $ fail "Cannot annotate expressions" void $ expr e TopDecl d -> void $ topDecl d @@ -103,262 +101,291 @@ checkSourceBlockParses = \case -- === Converting concrete syntax to abstract syntax === -type SyntaxM = FallibleM +type SyntaxM = Except -topDecl :: CTopDecl -> SyntaxM (UTopDecl VoidS VoidS) -topDecl = dropSrc topDecl' where - topDecl' (CSDecl ann d) = ULocalDecl <$> decl ann (WithSrc emptySrcPosCtx d) - topDecl' (CData name tyConParams givens constructors) = do - tyConParams' <- aExplicitParams tyConParams +topDecl :: CTopDeclW -> SyntaxM (UTopDecl VoidS VoidS) +topDecl (WithSrcs sid sids topDecl') = case topDecl' of + CSDecl ann d -> ULocalDecl <$> decl ann (WithSrcs sid sids d) + CData name tyConParams givens constructors -> do + tyConParams' <- fromMaybeM tyConParams Empty aExplicitParams givens' <- aOptGivens givens constructors' <- forM constructors \(v, ps) -> do - ps' <- toNest <$> mapM tyOptBinder ps + ps' <- fromMaybeM ps Empty \(WithSrcs _ _ ps') -> + toNest <$> mapM (tyOptBinder Explicit) ps' return (v, ps') return $ UDataDefDecl - (UDataDef name (catUOptAnnExplBinders givens' tyConParams') $ - map (\(name', cons) -> (name', UDataDefTrail cons)) constructors') - (fromString name) - (toNest $ map (fromString . fst) constructors') - topDecl' (CStruct name params givens fields defs) = do - params' <- aExplicitParams params + (UDataDef (withoutSrc name) (givens' >>> tyConParams') $ + map (\(name', cons) -> (withoutSrc name', UDataDefTrail cons)) constructors') + (fromSourceNameW name) + (toNest $ map (fromSourceNameW . fst) constructors') + CStruct name params givens fields defs -> do + params' <- fromMaybeM params Empty aExplicitParams givens' <- aOptGivens givens fields' <- forM fields \(v, ty) -> (v,) <$> expr ty methods <- forM defs \(ann, d) -> do - (methodName, lam) <- aDef d - return (ann, methodName, Abs (UBindSource emptySrcPosCtx "self") lam) - return $ UStructDecl (fromString name) (UStructDef name (catUOptAnnExplBinders givens' params') fields' methods) - topDecl' (CInterface name params methods) = do + (WithSrc _ methodName, lam) <- aDef d + return (ann, methodName, Abs (WithSrcB sid (UBindSource "self")) lam) + return $ UStructDecl (fromSourceNameW name) (UStructDef (withoutSrc name) (givens' >>> params') fields' methods) + CInterface name params methods -> do params' <- aExplicitParams params (methodNames, methodTys) <- unzip <$> forM methods \(methodName, ty) -> do ty' <- expr ty - return (fromString methodName, ty') - return $ UInterface params' methodTys (fromString name) (toNest methodNames) - topDecl' (CInstanceDecl def) = aInstanceDef def - topDecl' (CEffectDecl _ _) = error "not implemented" - topDecl' (CHandlerDecl _ _ _ _ _ _) = error "not implemented" - -decl :: LetAnn -> CSDecl -> SyntaxM (UDecl VoidS VoidS) -decl ann = propagateSrcB \case + return (fromSourceNameW methodName, ty') + return $ UInterface params' methodTys (fromSourceNameW name) (toNest methodNames) + CInstanceDecl def -> aInstanceDef def + +decl :: LetAnn -> CSDeclW -> SyntaxM (UDecl VoidS VoidS) +decl ann (WithSrcs sid _ d) = WithSrcB sid <$> case d of CLet binder rhs -> do (p, ty) <- patOptAnn binder ULet ann p ty <$> asExpr <$> block rhs - CBind _ _ -> throw SyntaxErr "Arrow binder syntax <- not permitted at the top level, because the binding would have unbounded scope." + CBind _ _ -> throw sid TopLevelArrowBinder CDefDecl def -> do (name, lam) <- aDef def - return $ ULet ann (fromString name) Nothing (ns $ ULam lam) + return $ ULet ann (fromSourceNameW name) Nothing (WithSrcE sid (ULam lam)) CExpr g -> UExprDecl <$> expr g CPass -> return UPass aInstanceDef :: CInstanceDef -> SyntaxM (UTopDecl VoidS VoidS) -aInstanceDef (CInstanceDef clName args givens methods instNameAndParams) = do - let clName' = fromString clName +aInstanceDef (CInstanceDef (WithSrc clNameId clName) args givens methods instNameAndParams) = do + let clName' = SourceName clNameId clName args' <- mapM expr args givens' <- aOptGivens givens methods' <- catMaybes <$> mapM aMethod methods case instNameAndParams of Nothing -> return $ UInstance clName' givens' args' methods' NothingB ImplicitApp - Just (instName, optParams) -> do - let instName' = JustB $ fromString instName + Just (WithSrc sid instName, optParams) -> do + let instName' = JustB $ WithSrcB sid $ UBindSource instName case optParams of Just params -> do params' <- aExplicitParams params - return $ UInstance clName' (catUOptAnnExplBinders givens' params') args' methods' instName' ExplicitApp + return $ UInstance clName' (givens' >>> params') args' methods' instName' ExplicitApp Nothing -> return $ UInstance clName' givens' args' methods' instName' ImplicitApp -aDef :: CDef -> SyntaxM (SourceName, ULamExpr VoidS) +aDef :: CDef -> SyntaxM (SourceNameW, ULamExpr VoidS) aDef (CDef name params optRhs optGivens body) = do - explicitParams <- aExplicitParams params + explicitParams <- explicitBindersOptAnn params let rhsDefault = (ExplicitApp, Nothing, Nothing) (expl, effs, resultTy) <- fromMaybeM optRhs rhsDefault \(expl, optEffs, resultTy) -> do effs <- fromMaybeM optEffs UPure aEffects resultTy' <- expr resultTy return (expl, Just effs, Just resultTy') implicitParams <- aOptGivens optGivens - let allParams = catUOptAnnExplBinders implicitParams explicitParams + let allParams = implicitParams >>> explicitParams body' <- block body return (name, ULamExpr allParams expl effs resultTy body') -catUOptAnnExplBinders :: UOptAnnExplBinders n l -> UOptAnnExplBinders l l' -> UOptAnnExplBinders n l' -catUOptAnnExplBinders (expls, bs) (expls', bs') = (expls <> expls', bs >>> bs') - -stripParens :: Group -> Group -stripParens (WithSrc _ (CParens [g])) = stripParens g +stripParens :: GroupW -> GroupW +stripParens (WithSrcs _ _ (CParens [g])) = stripParens g stripParens g = g -aExplicitParams :: ExplicitParams -> SyntaxM ([Explicitness], Nest UOptAnnBinder VoidS VoidS) -aExplicitParams gs = generalBinders DataParam Explicit gs - -aOptGivens :: Maybe GivenClause -> SyntaxM (UOptAnnExplBinders VoidS VoidS) -aOptGivens optGivens = do - (expls, implicitParams) <- unzip <$> fromMaybeM optGivens [] aGivens - return (expls, toNest implicitParams) - -aGivens :: GivenClause -> SyntaxM [(Explicitness, UOptAnnBinder VoidS VoidS)] -aGivens (implicits, optConstraints) = do - implicits' <- mapM (generalBinder DataParam (Inferred Nothing Unify)) implicits - constraints <- fromMaybeM optConstraints [] \gs -> do - mapM (generalBinder TypeParam (Inferred Nothing (Synth Full))) gs - return $ implicits' <> constraints - -generalBinders - :: ParamStyle -> Explicitness -> [Group] - -> SyntaxM ([Explicitness], Nest UOptAnnBinder VoidS VoidS) -generalBinders paramStyle expl params = do - (expls, bs) <- unzip . concat <$> forM params \case - WithSrc _ (CGivens gs) -> aGivens gs - p -> (:[]) <$> generalBinder paramStyle expl p - return (expls, toNest bs) - -generalBinder :: ParamStyle -> Explicitness -> Group - -> SyntaxM (Explicitness, UOptAnnBinder VoidS VoidS) -generalBinder paramStyle expl g = case expl of - Inferred _ (Synth _) -> (expl,) <$> tyOptBinder g - Inferred _ Unify -> do - b <- binderOptTy g - expl' <- return case b of - UAnnBinder (UBindSource _ s) _ _ -> Inferred (Just s) Unify - _ -> expl - return (expl', b) - Explicit -> (expl,) <$> case paramStyle of - TypeParam -> tyOptBinder g - DataParam -> binderOptTy g - - -- Binder pattern with an optional type annotation -patOptAnn :: Group -> SyntaxM (UPat VoidS VoidS, Maybe (UType VoidS)) -patOptAnn (Binary Colon lhs typeAnn) = (,) <$> pat lhs <*> (Just <$> expr typeAnn) -patOptAnn (WithSrc _ (CParens [g])) = patOptAnn g +-- === combinators for different sorts of binder lists === + +aOptGivens :: Maybe GivenClause -> SyntaxM (Nest UAnnBinder VoidS VoidS) +aOptGivens optGivens = fromMaybeM optGivens Empty aGivens + +binderList + :: [GroupW] -> (GroupW -> SyntaxM (Nest UAnnBinder VoidS VoidS)) + -> SyntaxM (Nest UAnnBinder VoidS VoidS) +binderList gs cont = concatNests <$> forM gs \case + WithSrcs _ _ (CGivens gs') -> aGivens gs' + g -> cont g + +withTrailingConstraints + :: GroupW -> (GroupW -> SyntaxM (UAnnBinder VoidS VoidS)) + -> SyntaxM (Nest UAnnBinder VoidS VoidS) +withTrailingConstraints g cont = case g of + WithSrcs _ _ (CBin (WithSrc _ Pipe) lhs c) -> do + Nest (UAnnBinder expl (WithSrcB sid b) ann cs) bs <- withTrailingConstraints lhs cont + s <- case b of + UBindSource s -> return s + UIgnore -> throw sid CantConstrainAnonBinders + UBind _ _ -> error "Shouldn't have internal names until renaming pass" + c' <- expr c + return $ UnaryNest (UAnnBinder expl (WithSrcB sid b) ann (cs ++ [c'])) + >>> bs + >>> UnaryNest (asConstraintBinder (mkUVar sid s) c') + _ -> UnaryNest <$> cont g + where + asConstraintBinder :: UExpr VoidS -> UConstraint VoidS -> UAnnBinder VoidS VoidS + asConstraintBinder v c = do + let sid = srcPos c + let t = WithSrcE sid (UApp c [v] []) + UAnnBinder (Inferred Nothing (Synth Full)) (WithSrcB sid UIgnore) (UAnn t) [] + +mkUVar :: SrcId -> SourceName -> UExpr VoidS +mkUVar sid v = WithSrcE sid $ UVar $ SourceName sid v + +aGivens :: GivenClause -> SyntaxM (Nest UAnnBinder VoidS VoidS) +aGivens ((WithSrcs _ _ implicits), optConstraints) = do + implicits' <- concatNests <$> forM implicits \b -> withTrailingConstraints b implicitArgBinder + constraints <- fromMaybeM optConstraints Empty (\(WithSrcs _ _ gs) -> toNest <$> mapM synthBinder gs) + return $ implicits' >>> constraints + +synthBinder :: GroupW -> SyntaxM (UAnnBinder VoidS VoidS) +synthBinder g = tyOptBinder (Inferred Nothing (Synth Full)) g + +concatNests :: [Nest b VoidS VoidS] -> Nest b VoidS VoidS +concatNests [] = Empty +concatNests (b:bs) = b >>> concatNests bs + +implicitArgBinder :: GroupW -> SyntaxM (UAnnBinder VoidS VoidS) +implicitArgBinder g = do + UAnnBinder _ b ann cs <- binderOptTy (Inferred Nothing Unify) g + s <- case b of + WithSrcB _ (UBindSource s) -> return $ Just s + _ -> return Nothing + return $ UAnnBinder (Inferred s Unify) b ann cs + +aExplicitParams :: ExplicitParams -> SyntaxM (Nest UAnnBinder VoidS VoidS) +aExplicitParams (WithSrcs _ _ bs) = binderList bs \b -> withTrailingConstraints b \b' -> + binderOptTy Explicit b' + +aPiBinders :: [GroupW] -> SyntaxM (Nest UAnnBinder VoidS VoidS) +aPiBinders bs = binderList bs \b -> + UnaryNest <$> tyOptBinder Explicit b + +explicitBindersOptAnn :: ExplicitParams -> SyntaxM (Nest UAnnBinder VoidS VoidS) +explicitBindersOptAnn (WithSrcs _ _ bs) = + binderList bs \b -> withTrailingConstraints b \b' -> binderOptTy Explicit b' + +-- === + +-- Binder pattern with an optional type annotation +patOptAnn :: GroupW -> SyntaxM (UPat VoidS VoidS, Maybe (UType VoidS)) +patOptAnn (WithSrcs _ _ (CBin (WithSrc _ Colon) lhs typeAnn)) = (,) <$> pat lhs <*> (Just <$> expr typeAnn) +patOptAnn (WithSrcs _ _ (CParens [g])) = patOptAnn g patOptAnn g = (,Nothing) <$> pat g -uBinder :: Group -> SyntaxM (UBinder c VoidS VoidS) -uBinder (WithSrc src b) = addSrcContext src $ case b of - CIdentifier name -> return $ fromString name - CHole -> return UIgnore - _ -> throw SyntaxErr "Binder must be an identifier or `_`" +uBinder :: GroupW -> SyntaxM (UBinder c VoidS VoidS) +uBinder (WithSrcs sid _ b) = case b of + CLeaf (CIdentifier name) -> return $ fromSourceNameW $ WithSrc sid name + CLeaf CHole -> return $ WithSrcB sid UIgnore + _ -> throw sid UnexpectedBinder -- Type annotation with an optional binder pattern -tyOptPat :: Group -> SyntaxM (UOptAnnBinder VoidS VoidS) -tyOptPat = \case +tyOptPat :: GroupW -> SyntaxM (UAnnBinder VoidS VoidS) +tyOptPat grpTop@(WithSrcs sid _ grp) = case grp of -- Named type - Binary Colon lhs typeAnn -> UAnnBinder <$> uBinder lhs <*> (UAnn <$> expr typeAnn) <*> pure [] + CBin (WithSrc _ Colon) lhs typeAnn -> + UAnnBinder Explicit <$> uBinder lhs <*> (UAnn <$> expr typeAnn) <*> pure [] -- Binder in grouping parens. - WithSrc _ (CParens [g]) -> tyOptPat g + CParens [g] -> tyOptPat g -- Anonymous type - g -> UAnnBinder UIgnore <$> (UAnn <$> expr g) <*> pure [] + _ -> UAnnBinder Explicit (WithSrcB sid UIgnore) <$> (UAnn <$> expr grpTop) <*> pure [] -- Pattern of a case binder. This treats bare names specially, in -- that they become (nullary) constructors to match rather than names -- to bind. -casePat :: Group -> SyntaxM (UPat VoidS VoidS) +casePat :: GroupW -> SyntaxM (UPat VoidS VoidS) casePat = \case - (WithSrc src (CIdentifier name)) -> return $ WithSrcB src $ UPatCon (fromString name) Empty + WithSrcs src _ (CLeaf (CIdentifier name)) -> + return $ WithSrcB src $ UPatCon (fromSourceNameW (WithSrc src name)) Empty g -> pat g -pat :: Group -> SyntaxM (UPat VoidS VoidS) -pat = propagateSrcB pat' where - pat' (CBin (WithSrc _ DepComma) lhs rhs) = do +pat :: GroupW -> SyntaxM (UPat VoidS VoidS) +pat (WithSrcs sid _ grp) = WithSrcB sid <$> case grp of + CBin (WithSrc _ DepComma) lhs rhs -> do lhs' <- pat lhs rhs' <- pat rhs return $ UPatDepPair $ PairB lhs' rhs' - pat' (CBrackets gs) = UPatTable . toNest <$> (mapM pat gs) + CBrackets gs -> UPatTable . toNest <$> (mapM pat gs) -- TODO: use Python-style trailing comma (like `(x,y,)`) for singleton tuples - pat' (CParens [g]) = dropSrcB <$> casePat g - pat' (CParens gs) = UPatProd . toNest <$> mapM pat gs - pat' CHole = return $ UPatBinder UIgnore - pat' (CIdentifier name) = return $ UPatBinder $ fromString name - pat' (CBin (WithSrc _ JuxtaposeWithSpace) lhs rhs) = do + CParens gs -> case gs of + [g] -> do + WithSrcB _ g' <- casePat g + return g' + _ -> UPatProd . toNest <$> mapM pat gs + CLeaf CHole -> return $ UPatBinder (WithSrcB sid UIgnore) + CLeaf (CIdentifier name) -> return $ UPatBinder $ fromSourceNameW $ WithSrc sid name + CJuxtapose True lhs rhs -> do case lhs of - WithSrc _ (CBin (WithSrc _ JuxtaposeWithSpace) _ _) -> - throw SyntaxErr "Only unary constructors can form patterns without parens" + WithSrcs lhsId _ (CJuxtapose True _ _) -> throw lhsId OnlyUnaryWithoutParens _ -> return () name <- identifier "pattern constructor name" lhs arg <- pat rhs - return $ UPatCon (fromString name) (UnaryNest arg) - pat' (CBin (WithSrc _ JuxtaposeNoSpace) lhs rhs) = do + return $ UPatCon (fromSourceNameW name) (UnaryNest arg) + CJuxtapose False lhs rhs -> do name <- identifier "pattern constructor name" lhs case rhs of - WithSrc _ (CParens gs) -> UPatCon (fromString name) . toNest <$> mapM pat gs + WithSrcs _ _ (CParens gs) -> do + gs' <- mapM pat gs + return $ UPatCon (fromSourceNameW name) (toNest gs') _ -> error "unexpected postfix group (should be ruled out at grouping stage)" - pat' _ = throw SyntaxErr "Illegal pattern" + _ -> throw sid IllegalPattern -data ParamStyle - = TypeParam -- binder optional, used in pi types - | DataParam -- type optional , used in lambda +tyOptBinder :: Explicitness -> GroupW -> SyntaxM (UAnnBinder VoidS VoidS) +tyOptBinder expl (WithSrcs sid sids grp) = case grp of + CBin (WithSrc _ Pipe) _ rhs -> throw (getSrcId rhs) UnexpectedConstraint + CBin (WithSrc _ Colon) name ty -> do + b <- uBinder name + ann <- UAnn <$> expr ty + return $ UAnnBinder expl b ann [] + g -> do + ty <- expr (WithSrcs sid sids g) + return $ UAnnBinder expl (WithSrcB sid UIgnore) (UAnn ty) [] -tyOptBinder :: Group -> SyntaxM (UAnnBinder req VoidS VoidS) -tyOptBinder = \case - Binary Pipe _ _ -> throw SyntaxErr "Unexpected constraint" - Binary Colon name ty -> do +binderOptTy :: Explicitness -> GroupW -> SyntaxM (UAnnBinder VoidS VoidS) +binderOptTy expl = \case + WithSrcs _ _ (CBin (WithSrc _ Colon) name ty) -> do b <- uBinder name ann <- UAnn <$> expr ty - return $ UAnnBinder b ann [] + return $ UAnnBinder expl b ann [] g -> do - ty <- expr g - return $ UAnnBinder UIgnore (UAnn ty) [] - -binderOptTy :: Group -> SyntaxM (UOptAnnBinder VoidS VoidS) -binderOptTy g = do - (g', constraints) <- trailingConstraints g - case g' of - Binary Colon name ty -> do - b <- uBinder name - ann <- UAnn <$> expr ty - return $ UAnnBinder b ann constraints - _ -> do - b <- uBinder g' - return $ UAnnBinder b UNoAnn constraints - -trailingConstraints :: Group -> SyntaxM (Group, [UConstraint VoidS]) -trailingConstraints gTop = go [] gTop where - go :: [UConstraint VoidS] -> Group -> SyntaxM (Group, [UConstraint VoidS]) - go cs = \case - Binary Pipe lhs c -> do - c' <- expr c - go (c':cs) lhs - g -> return (g, cs) - -argList :: [Group] -> SyntaxM ([UExpr VoidS], [UNamedArg VoidS]) + b <- uBinder g + return $ UAnnBinder expl b UNoAnn [] + +binderReqTy :: Explicitness -> GroupW -> SyntaxM (UAnnBinder VoidS VoidS) +binderReqTy expl (WithSrcs _ _ (CBin (WithSrc _ Colon) name ty)) = do + b <- uBinder name + ann <- UAnn <$> expr ty + return $ UAnnBinder expl b ann [] +binderReqTy _ g = throw (getSrcId g) ExpectedAnnBinder + +argList :: [GroupW] -> SyntaxM ([UExpr VoidS], [UNamedArg VoidS]) argList gs = partitionEithers <$> mapM singleArg gs -singleArg :: Group -> SyntaxM (Either (UExpr VoidS) (UNamedArg VoidS)) +singleArg :: GroupW -> SyntaxM (Either (UExpr VoidS) (UNamedArg VoidS)) singleArg = \case - WithSrc src (CBin (WithSrc _ CSEqual) lhs rhs) -> addSrcContext src $ Right <$> - ((,) <$> identifier "named argument" lhs <*> expr rhs) + WithSrcs _ _ (CBin (WithSrc _ CSEqual) lhs rhs) -> Right <$> + ((,) <$> withoutSrc <$> identifier "named argument" lhs <*> expr rhs) g -> Left <$> expr g -identifier :: String -> Group -> SyntaxM SourceName -identifier ctx = dropSrc identifier' where - identifier' (CIdentifier name) = return name - identifier' _ = throw SyntaxErr $ "Expected " ++ ctx ++ " to be an identifier" +identifier :: String -> GroupW -> SyntaxM SourceNameW +identifier ctx (WithSrcs sid _ g) = case g of + CLeaf (CIdentifier name) -> return $ WithSrc sid name + _ -> throw sid $ ExpectedIdentifier ctx -aEffects :: ([Group], Maybe Group) -> SyntaxM (UEffectRow VoidS) -aEffects (effs, optEffTail) = do +aEffects :: WithSrcs ([GroupW], Maybe GroupW) -> SyntaxM (UEffectRow VoidS) +aEffects (WithSrcs _ _ (effs, optEffTail)) = do lhs <- mapM effect effs rhs <- forM optEffTail \effTail -> - fromString <$> identifier "effect row remainder variable" effTail + fromSourceNameW <$> identifier "effect row remainder variable" effTail return $ UEffectRow (S.fromList lhs) rhs -effect :: Group -> SyntaxM (UEffect VoidS) -effect (WithSrc _ (CParens [g])) = effect g -effect (Binary JuxtaposeWithSpace (Identifier "Read") (Identifier h)) = - return $ URWSEffect Reader $ fromString h -effect (Binary JuxtaposeWithSpace (Identifier "Accum") (Identifier h)) = - return $ URWSEffect Writer $ fromString h -effect (Binary JuxtaposeWithSpace (Identifier "State") (Identifier h)) = - return $ URWSEffect State $ fromString h -effect (Identifier "Except") = return UExceptionEffect -effect (Identifier "IO") = return UIOEffect -effect _ = throw SyntaxErr "Unexpected effect form; expected one of `Read h`, `Accum h`, `State h`, `Except`, `IO`, or the name of a user-defined effect." - -aMethod :: CSDecl -> SyntaxM (Maybe (UMethodDef VoidS)) -aMethod (WithSrc _ CPass) = return Nothing -aMethod (WithSrc src d) = Just . WithSrcE src <$> addSrcContext src case d of +effect :: GroupW -> SyntaxM (UEffect VoidS) +effect (WithSrcs grpSid _ grp) = case grp of + CParens [g] -> effect g + CJuxtapose True (Identifier "Read" ) (WithSrcs sid _ (CLeaf (CIdentifier h))) -> + return $ URWSEffect Reader $ fromSourceNameW (WithSrc sid h) + CJuxtapose True (Identifier "Accum") (WithSrcs sid _ (CLeaf (CIdentifier h))) -> + return $ URWSEffect Writer $ fromSourceNameW (WithSrc sid h) + CJuxtapose True (Identifier "State") (WithSrcs sid _ (CLeaf (CIdentifier h))) -> + return $ URWSEffect State $ fromSourceNameW (WithSrc sid h) + CLeaf (CIdentifier "Except") -> return UExceptionEffect + CLeaf (CIdentifier "IO" ) -> return UIOEffect + _ -> throw grpSid UnexpectedEffectForm + +aMethod :: CSDeclW -> SyntaxM (Maybe (UMethodDef VoidS)) +aMethod (WithSrcs _ _ CPass) = return Nothing +aMethod (WithSrcs sid _ d) = Just . WithSrcE sid <$> case d of CDefDecl def -> do - (name, lam) <- aDef def - return $ UMethodDef (fromString name) lam - CLet (WithSrc _ (CIdentifier name)) rhs -> do - rhs' <- ULamExpr ([], Empty) ImplicitApp Nothing Nothing <$> block rhs - return $ UMethodDef (fromString name) rhs' - _ -> throw SyntaxErr "Unexpected method definition. Expected `def` or `x = ...`." + (WithSrc nameSid name, lam) <- aDef def + return $ UMethodDef (SourceName nameSid name) lam + CLet (WithSrcs lhsSid _ (CLeaf (CIdentifier name))) rhs -> do + rhs' <- ULamExpr Empty ImplicitApp Nothing Nothing <$> block rhs + return $ UMethodDef (fromSourceNameW (WithSrc lhsSid name)) rhs' + _ -> throw sid UnexpectedMethodDef asExpr :: UBlock VoidS -> UExpr VoidS asExpr (WithSrcE src b) = case b of @@ -366,22 +393,22 @@ asExpr (WithSrcE src b) = case b of _ -> WithSrcE src $ UDo $ WithSrcE src b block :: CSBlock -> SyntaxM (UBlock VoidS) -block (ExprBlock g) = WithSrcE emptySrcPosCtx . UBlock Empty <$> expr g -block (IndentedBlock decls) = do +block (ExprBlock g) = WithSrcE (srcPos g) . UBlock Empty <$> expr g +block (IndentedBlock sid decls) = do (decls', result) <- blockDecls decls - return $ WithSrcE emptySrcPosCtx $ UBlock decls' result + return $ WithSrcE sid $ UBlock decls' result -blockDecls :: [CSDecl] -> SyntaxM (Nest UDecl VoidS VoidS, UExpr VoidS) +blockDecls :: [CSDeclW] -> SyntaxM (Nest UDecl VoidS VoidS, UExpr VoidS) blockDecls [] = error "shouldn't have empty list of decls" -blockDecls [WithSrc src d] = addSrcContext src case d of +blockDecls [WithSrcs sid _ d] = case d of CExpr g -> (Empty,) <$> expr g - _ -> throw SyntaxErr "Block must end in expression" -blockDecls (WithSrc pos (CBind b rhs):ds) = do - (_, b') <- generalBinder DataParam Explicit b + _ -> throw sid BlockWithoutFinalExpr +blockDecls (WithSrcs sid _ (CBind b rhs):ds) = do + b' <- binderOptTy Explicit b rhs' <- asExpr <$> block rhs - body <- block $ IndentedBlock ds - let lam = ULam $ ULamExpr ([Explicit], UnaryNest b') ExplicitApp Nothing Nothing body - return (Empty, WithSrcE pos $ extendAppRight rhs' (ns lam)) + body <- block $ IndentedBlock sid ds -- Not really the right SrcId + let lam = ULam $ ULamExpr (UnaryNest b') ExplicitApp Nothing Nothing body + return (Empty, WithSrcE sid $ extendAppRight rhs' (WithSrcE sid lam)) blockDecls (d:ds) = do d' <- decl PlainLet d (ds', e) <- blockDecls ds @@ -389,86 +416,76 @@ blockDecls (d:ds) = do -- === Concrete to abstract syntax of expressions === -expr :: Group -> SyntaxM (UExpr VoidS) -expr = propagateSrcE expr' where - expr' CEmpty = return UHole - -- Binders (e.g., in pi types) should not hit this case - expr' (CIdentifier name) = return $ fromString name - expr' (CPrim prim xs) = UPrim prim <$> mapM expr xs - expr' (CNat word) = return $ UNatLit word - expr' (CInt int) = return $ UIntLit int - expr' (CString str) = return $ explicitApp (fromString "to_list") - [ns $ UTabCon $ map (ns . charExpr) str] - expr' (CChar char) = return $ charExpr char - expr' (CFloat num) = return $ UFloatLit num - expr' CHole = return UHole - expr' (CParens [g]) = dropSrcE <$> expr g - expr' (CParens gs) = UPrim UTuple <$> mapM expr gs +expr :: GroupW -> SyntaxM (UExpr VoidS) +expr (WithSrcs sid _ grp) = WithSrcE sid <$> case grp of + CLeaf x -> leaf sid x + CPrim prim xs -> UPrim prim <$> mapM expr xs + CParens [g] -> do + WithSrcE _ result <- expr g + return result + CParens gs -> UPrim UTuple <$> mapM expr gs -- Table constructors here. Other uses of square brackets -- should be detected upstream, before calling expr. - expr' (CBrackets gs) = UTabCon <$> mapM expr gs - expr' (CGivens _) = throw SyntaxErr $ "Unexpected `given` clause" - expr' (CArrow lhs effs rhs) = do + CBrackets gs -> UTabCon <$> mapM expr gs + CGivens _ -> throw sid UnexpectedGivenClause + CArrow lhs effs rhs -> do case lhs of - WithSrc _ (CParens gs) -> do - bs <- generalBinders TypeParam Explicit gs + WithSrcs _ _ (CParens gs) -> do + bs <- aPiBinders gs effs' <- fromMaybeM effs UPure aEffects resultTy <- expr rhs return $ UPi $ UPiExpr bs ExplicitApp effs' resultTy - _ -> throw SyntaxErr "Argument types should be in parentheses" - expr' (CDo b) = UDo <$> block b - -- Binders (e.g., in pi types) should not hit this case - expr' (CBin (WithSrc opSrc op) lhs rhs) = - case op of - JuxtaposeNoSpace -> do - f <- expr lhs - case rhs of - WithSrc _ (CParens args) -> do - (posArgs, namedArgs) <- argList args - return $ UApp f posArgs namedArgs - WithSrc _ (CBrackets args) -> do - args' <- mapM expr args - return $ UTabApp f args' - _ -> error "unexpected postfix group (should be ruled out at grouping stage)" - JuxtaposeWithSpace -> extendAppRight <$> expr lhs <*> expr rhs - Dollar -> extendAppRight <$> expr lhs <*> expr rhs - Pipe -> extendAppLeft <$> expr lhs <*> expr rhs - Dot -> do - lhs' <- expr lhs - WithSrc src rhs' <- return rhs - name <- addSrcContext src $ case rhs' of - CIdentifier name -> return $ FieldName name - CNat i -> return $ FieldNum $ fromIntegral i - _ -> throw SyntaxErr "Field must be a name or an integer" - return $ UFieldAccess lhs' (WithSrc src name) - DoubleColon -> UTypeAnn <$> (expr lhs) <*> expr rhs - EvalBinOp s -> evalOp s - DepAmpersand -> do - lhs' <- tyOptPat lhs - UDepPairTy . (UDepPairType ExplicitDepPair lhs') <$> expr rhs - DepComma -> UDepPair <$> (expr lhs) <*> expr rhs - CSEqual -> throw SyntaxErr "Equal sign must be used as a separator for labels or binders, not a standalone operator" - Colon -> throw SyntaxErr "Colon separates binders from their type annotations, is not a standalone operator.\nIf you are trying to write a dependent type, use parens: (i:Fin 4) => (..i)" - ImplicitArrow -> case lhs of - WithSrc _ (CParens gs) -> do - bs <- generalBinders TypeParam Explicit gs - resultTy <- expr rhs - return $ UPi $ UPiExpr bs ImplicitApp UPure resultTy - _ -> throw SyntaxErr "Argument types should be in parentheses" - FatArrow -> do - lhs' <- tyOptPat lhs - UTabPi . (UTabPiExpr lhs') <$> expr rhs - where - evalOp s = do - let f = WithSrcE opSrc (fromString s) - lhs' <- expr lhs - rhs' <- expr rhs - return $ explicitApp f [lhs', rhs'] - expr' (CPrefix name g) = + WithSrcs lhsSid _ _ -> throw lhsSid ArgsShouldHaveParens + CDo b -> UDo <$> block b + CJuxtapose hasSpace lhs rhs -> case hasSpace of + True -> extendAppRight <$> expr lhs <*> expr rhs + False -> do + f <- expr lhs + case rhs of + WithSrcs _ _ (CParens args) -> do + (posArgs, namedArgs) <- argList args + return $ UApp f posArgs namedArgs + WithSrcs _ _ (CBrackets args) -> do + args' <- mapM expr args + return $ UTabApp f args' + _ -> error "unexpected postfix group (should be ruled out at grouping stage)" + CBin (WithSrc opSid op) lhs rhs -> case op of + Dollar -> extendAppRight <$> expr lhs <*> expr rhs + Pipe -> extendAppLeft <$> expr lhs <*> expr rhs + Dot -> do + lhs' <- expr lhs + WithSrcs rhsSid _ rhs' <- return rhs + name <- case rhs' of + CLeaf (CIdentifier name) -> return $ FieldName name + CLeaf (CNat i ) -> return $ FieldNum $ fromIntegral i + _ -> throw rhsSid BadField + return $ UFieldAccess lhs' (WithSrc rhsSid name) + DoubleColon -> UTypeAnn <$> (expr lhs) <*> expr rhs + EvalBinOp s -> evalOp s + DepAmpersand -> do + lhs' <- tyOptPat lhs + UDepPairTy . (UDepPairType ExplicitDepPair lhs') <$> expr rhs + DepComma -> UDepPair <$> (expr lhs) <*> expr rhs + CSEqual -> throw opSid BadEqualSign + Colon -> throw opSid BadColon + ImplicitArrow -> case lhs of + WithSrcs _ _ (CParens gs) -> do + bs <- aPiBinders gs + resultTy <- expr rhs + return $ UPi $ UPiExpr bs ImplicitApp UPure resultTy + WithSrcs lhsSid _ _ -> throw lhsSid ArgsShouldHaveParens + FatArrow -> do + lhs' <- tyOptPat lhs + UTabPi . (UTabPiExpr lhs') <$> expr rhs + where + evalOp s = do + let f = WithSrcE opSid (fromSourceNameW (WithSrc opSid s)) + lhs' <- expr lhs + rhs' <- expr rhs + return $ explicitApp f [lhs', rhs'] + CPrefix (WithSrc prefixSid name) g -> do case name of - ".." -> range "RangeTo" <$> expr g - "..<" -> range "RangeToExc" <$> expr g - "+" -> (dropSrcE <$> expr g) <&> \case + "+" -> (withoutSrc <$> expr g) <&> \case UNatLit i -> UIntLit (fromIntegral i) UIntLit i -> UIntLit i UFloatLit i -> UFloatLit i @@ -477,68 +494,72 @@ expr = propagateSrcE expr' where WithSrcE _ (UNatLit i) -> UIntLit (-(fromIntegral i)) WithSrcE _ (UIntLit i) -> UIntLit (-i) WithSrcE _ (UFloatLit i) -> UFloatLit (-i) - e -> unaryApp "neg" e - _ -> throw SyntaxErr $ "Prefix (" ++ name ++ ") not legal as a bare expression" - where - range :: UExpr VoidS -> UExpr VoidS -> UExpr' VoidS - range rangeName lim = explicitApp rangeName [ns UHole, lim] - expr' (CPostfix name g) = - case name of - ".." -> range "RangeFrom" <$> expr g - "<.." -> range "RangeFromExc" <$> expr g - _ -> throw SyntaxErr $ "Postfix (" ++ name ++ ") not legal as a bare expression" - where - range :: UExpr VoidS -> UExpr VoidS -> UExpr' VoidS - range rangeName lim = explicitApp rangeName [ns UHole, lim] - expr' (CLambda params body) = do - params' <- aExplicitParams $ map stripParens params + e -> unaryApp (mkUVar prefixSid "neg") e + _ -> throw prefixSid $ BadPrefix $ pprint name + CLambda params body -> do + params' <- explicitBindersOptAnn $ WithSrcs sid [] $ map stripParens params body' <- block body return $ ULam $ ULamExpr params' ExplicitApp Nothing Nothing body' - expr' (CFor kind indices body) = do + CFor kind indices body -> do let (dir, trailingUnit) = case kind of KFor -> (Fwd, False) KFor_ -> (Fwd, True) KRof -> (Rev, False) KRof_ -> (Rev, True) -- TODO: Can we fetch the source position from the error context, to feed into `buildFor`? - e <- buildFor (0, 0) dir <$> mapM binderOptTy indices <*> block body + e <- buildFor sid dir <$> mapM (binderOptTy Explicit) indices <*> block body if trailingUnit - then return $ UDo $ ns $ UBlock (UnaryNest (nsB $ UExprDecl e)) (ns unitExpr) - else return $ dropSrcE e - expr' (CCase scrut alts) = UCase <$> (expr scrut) <*> mapM alternative alts + then return $ UDo $ WithSrcE sid $ UBlock (UnaryNest (WithSrcB sid $ UExprDecl e)) (unitExpr sid) + else return $ withoutSrc e + CCase scrut alts -> UCase <$> (expr scrut) <*> mapM alternative alts where alternative (match, body) = UAlt <$> casePat match <*> block body - expr' (CIf p c a) = do + CIf p c a -> do p' <- expr p c' <- block c a' <- case a of - Nothing -> return $ ns $ UBlock Empty $ ns unitExpr + Nothing -> return $ WithSrcE sid $ UBlock Empty $ unitExpr sid (Just alternative) -> block alternative return $ UCase p' - [ UAlt (nsB $ UPatCon "True" Empty) c' - , UAlt (nsB $ UPatCon "False" Empty) a'] - expr' (CWith lhs rhs) = do + [ UAlt (WithSrcB sid $ UPatCon (SourceName sid "True") Empty) c' + , UAlt (WithSrcB sid $ UPatCon (SourceName sid "False") Empty) a'] + CWith lhs rhs -> do ty <- expr lhs case rhs of - [b] -> do - b' <- binderOptTy b + WithSrcs _ _ [b] -> do + b' <- binderReqTy Explicit b return $ UDepPairTy $ UDepPairType ImplicitDepPair b' ty _ -> error "n-ary dependent pairs not implemented" +leaf :: SrcId -> CLeaf -> SyntaxM (UExpr' VoidS) +leaf sid = \case + -- Binders (e.g., in pi types) should not hit this case + CIdentifier name -> return $ fromSourceNameW $ WithSrc sid name + CNat word -> return $ UNatLit word + CInt int -> return $ UIntLit int + CString str -> do + xs <- return $ map (WithSrcE sid . charExpr) str + let toListVar = mkUVar sid "to_list" + return $ explicitApp toListVar [WithSrcE sid (UTabCon xs)] + CChar char -> return $ charExpr char + CFloat num -> return $ UFloatLit num + CHole -> return UHole + charExpr :: Char -> (UExpr' VoidS) charExpr c = ULit $ Word8Lit $ fromIntegral $ fromEnum c -unitExpr :: UExpr' VoidS -unitExpr = UPrim (UCon $ P.ProdCon) [] +unitExpr :: SrcId -> UExpr VoidS +unitExpr sid = WithSrcE sid $ UPrim (UCon $ P.ProdCon) [] -- === Builders === -- TODO Does this generalize? Swap list for Nest? -buildFor :: SrcPos -> Direction -> [UOptAnnBinder VoidS VoidS] -> UBlock VoidS -> UExpr VoidS -buildFor pos dir binders body = case binders of +-- TODO: these SrcIds aren't really correct +buildFor :: SrcId -> Direction -> [UAnnBinder VoidS VoidS] -> UBlock VoidS -> UExpr VoidS +buildFor sid dir binders body = case binders of [] -> error "should have nonempty list of binder" - [b] -> WithSrcE (fromPos pos) $ UFor dir $ UForExpr b body - b:bs -> WithSrcE (fromPos pos) $ UFor dir $ UForExpr b $ - ns $ UBlock Empty $ buildFor pos dir bs body + [b] -> WithSrcE sid $ UFor dir $ UForExpr b body + b:bs -> WithSrcE sid $ UFor dir $ UForExpr b $ + WithSrcE sid $ UBlock Empty $ buildFor sid dir bs body -- === Helpers === @@ -556,26 +577,5 @@ unaryApp f x = UApp f [x] [] explicitApp :: UExpr n -> [UExpr n] -> UExpr' n explicitApp f xs = UApp f xs [] -ns :: (a n) -> WithSrcE a n -ns = WithSrcE emptySrcPosCtx - -nsB :: (b n l) -> WithSrcB b n l -nsB = WithSrcB emptySrcPosCtx - toNest :: [a VoidS VoidS] -> Nest a VoidS VoidS toNest = foldr Nest Empty - -dropSrc :: (t -> SyntaxM a) -> WithSrc t -> SyntaxM a -dropSrc act (WithSrc src x) = addSrcContext src $ act x - -propagateSrcE :: (t -> SyntaxM (e n)) -> WithSrc t -> SyntaxM (WithSrcE e n) -propagateSrcE act (WithSrc src x) = addSrcContext src (WithSrcE src <$> act x) - -dropSrcE :: WithSrcE e n -> e n -dropSrcE (WithSrcE _ x) = x - -propagateSrcB :: (t -> SyntaxM (binder n l)) -> WithSrc t -> SyntaxM (WithSrcB binder n l) -propagateSrcB act (WithSrc src x) = addSrcContext src (WithSrcB src <$> act x) - -dropSrcB :: WithSrcB binder n l -> binder n l -dropSrcB (WithSrcB _ x) = x diff --git a/src/lib/Actor.hs b/src/lib/Actor.hs index 3fb452c06..59ff089a5 100644 --- a/src/lib/Actor.hs +++ b/src/lib/Actor.hs @@ -1,70 +1,259 @@ --- Copyright 2022 Google LLC +-- Copyright 2023 Google LLC -- -- Use of this source code is governed by a BSD-style -- license that can be found in the LICENSE file or at -- https://developers.google.com/open-source/licenses/bsd -module Actor (PChan, sendPChan, sendOnly, subChan, - Actor, runActor, spawn, - LogServerMsg (..), logServer) where +{-# LANGUAGE UndecidableInstances #-} -import Control.Concurrent (Chan, forkIO, newChan, readChan, ThreadId, writeChan) +module Actor ( + ActorM, Actor (..), launchActor, send, selfMailbox, messageLoop, + sliceMailbox, SubscribeMsg (..), IncServer, IncServerT, FileWatcher, + StateServer, flushDiffs, handleSubscribeMsg, subscribe, subscribeIO, sendSync, + runIncServerT, launchFileWatcher, Mailbox, launchIncFunctionEvaluator + ) where + +import Control.Concurrent +import Control.Monad import Control.Monad.State.Strict +import Control.Monad.Reader +import qualified Data.ByteString as BS +import Data.IORef +import Data.Text.Encoding qualified as T +import Data.Text (Text) +import System.Directory (getModificationTime) +import GHC.Generics + +import IncState +import MonadUtil -import Util (onFst, onSnd) +-- === Actor implementation === --- Micro-actors. +newtype ActorM msg a = ActorM { runActorM :: ReaderT (Chan msg) IO a } + deriving (Functor, Applicative, Monad, MonadIO) --- In this model, an "actor" is just an IO computation (presumably --- running on its own Haskell thread) that receives messages on a --- Control.Concurrent.Chan channel. The idea is that the actor thread --- only receives information (or synchronization) from other threads --- through messages sent on that one channel, and no other thread --- reads messages from that channel. +newtype Mailbox a = Mailbox { sendToMailbox :: a -> IO () } --- We start the actor with a two-way view of its input channel so it --- can subscribe itself to message streams by passing (a send-only --- view of) it to another actor. -type Actor a = Chan a -> IO () +class (Show msg, MonadIO m) => Actor msg m | m -> msg where + selfChan :: m (Chan msg) --- We also define a send-only channel type, to help ourselves not --- accidentally read from channels we aren't supposed to. -newtype PChan a = PChan { sendPChan :: a -> IO () } +instance Show msg => Actor msg (ActorM msg) where + selfChan = ActorM ask -sendOnly :: Chan a -> PChan a -sendOnly chan = PChan $ \ !x -> writeChan chan x +instance Actor msg m => Actor msg (ReaderT r m) where selfChan = lift $ selfChan +instance Actor msg m => Actor msg (StateT s m) where selfChan = lift $ selfChan -subChan :: (a -> b) -> PChan b -> PChan a -subChan f chan = PChan (sendPChan chan . f) +send :: MonadIO m => Mailbox msg -> msg -> m () +send chan msg = liftIO $ sendToMailbox chan msg --- Synchronously execute an actor. -runActor :: Actor a -> IO () -runActor actor = newChan >>= actor +selfMailbox :: Actor msg m => (a -> msg) -> m (Mailbox a) +selfMailbox asSelfMessage = do + chan <- selfChan + return $ Mailbox \msg -> writeChan chan (asSelfMessage msg) --- Asynchronously launch an actor. Immediately returns permission to --- kill that actor and to send it messages. -spawn :: Actor a -> IO (ThreadId, PChan a) -spawn actor = do +launchActor :: MonadIO m => ActorM msg () -> m (Mailbox msg) +launchActor m = liftIO do chan <- newChan - tid <- forkIO $ actor chan - return (tid, sendOnly chan) - --- A log server. Combines inputs monoidally and pushes incremental --- updates to subscribers. - -data LogServerMsg a = Subscribe (PChan a) - | Publish a - -logServer :: Monoid a => Actor (LogServerMsg a) -logServer self = flip evalStateT (mempty, []) $ forever $ do - msg <- liftIO $ readChan self - case msg of - Subscribe chan -> do - curVal <- gets fst - liftIO $ chan `sendPChan` curVal - modify $ onSnd (chan:) - Publish x -> do - modify $ onFst (<> x) - subscribers <- gets snd - mapM_ (liftIO . (`sendPChan` x)) subscribers + void $ forkIO $ runReaderT (runActorM m) chan + return $ Mailbox \msg -> writeChan chan msg + +messageLoop :: Actor msg m => (msg -> m ()) -> m () +messageLoop handleMessage = do + forever do + msg <- liftIO . readChan =<< selfChan + handleMessage msg + +sliceMailbox :: (b -> a) -> Mailbox a -> Mailbox b +sliceMailbox f (Mailbox sendMsg) = Mailbox $ sendMsg . f + +-- === Promises === + +newtype Promise a = Promise (MVar a) +newtype PromiseSetter a = PromiseSetter (MVar a) + +newPromise :: MonadIO m => m (Promise a, PromiseSetter a) +newPromise = do + v <- liftIO $ newEmptyMVar + return (Promise v, PromiseSetter v) + +waitForPromise :: MonadIO m => Promise a -> m a +waitForPromise (Promise v) = liftIO $ readMVar v + +setPromise :: MonadIO m => PromiseSetter a -> a -> m () +setPromise (PromiseSetter v) x = liftIO $ putMVar v x + +-- Message that expects a synchronous reponse +data SyncMsg msg response = SyncMsg msg (PromiseSetter response) + +sendSync :: MonadIO m => Mailbox (SyncMsg msg response) -> msg -> m response +sendSync mailbox msg = do + (result, resultSetter) <- newPromise + send mailbox (SyncMsg msg resultSetter) + waitForPromise result + + +-- === Diff server === + +data IncServerState s d = IncServerState + { subscribers :: [Mailbox d] + , bufferedUpdates :: d + , curIncState :: s } + deriving (Show, Generic) + +class (Monoid d, MonadIO m) => IncServer s d m | m -> s, m -> d where + getIncServerStateRef :: m (IORef (IncServerState s d)) + +data SubscribeMsg s d = Subscribe (SyncMsg (Mailbox d) s) deriving (Show) + +getIncServerState :: IncServer s d m => m (IncServerState s d) +getIncServerState = readRef =<< getIncServerStateRef + +updateIncServerState :: IncServer s d m => (IncServerState s d -> IncServerState s d) -> m () +updateIncServerState f = do + ref <- getIncServerStateRef + prev <- readRef ref + writeRef ref $ f prev + +handleSubscribeMsg :: IncServer s d m => SubscribeMsg s d -> m () +handleSubscribeMsg (Subscribe (SyncMsg newSub response)) = do + flushDiffs + updateIncServerState \s -> s { subscribers = newSub : subscribers s } + curState <- curIncState <$> getIncServerState + setPromise response curState + +flushDiffs :: IncServer s d m => m () +flushDiffs = do + d <- bufferedUpdates <$> getIncServerState + updateIncServerState \s -> s { bufferedUpdates = mempty } + subs <- subscribers <$> getIncServerState + -- TODO: consider testing for emptiness here + forM_ subs \sub -> send sub d + +type StateServer s d = Mailbox (SubscribeMsg s d) + +subscribe :: Actor msg m => (d -> msg) -> StateServer s d -> m s +subscribe inject server = do + updateChannel <- selfMailbox inject + sendSync (sliceMailbox Subscribe server) updateChannel + +subscribeIO :: StateServer s d -> IO (s, Chan d) +subscribeIO server = do + chan <- newChan + let mailbox = Mailbox (writeChan chan) + s <- sendSync (sliceMailbox Subscribe server) mailbox + return (s, chan) + +newtype IncServerT s d m a = IncServerT { runIncServerT' :: ReaderT (Ref (IncServerState s d)) m a } + deriving (Functor, Applicative, Monad, MonadIO, Actor msg, FreshNames name, MonadTrans) + +instance (MonadIO m, IncState s d) => IncServer s d (IncServerT s d m) where + getIncServerStateRef = IncServerT ask + +instance (MonadIO m, IncState s d) => DefuncState d (IncServerT s d m) where + update d = updateIncServerState \s -> s + { bufferedUpdates = bufferedUpdates s <> d + , curIncState = curIncState s `applyDiff` d} + +instance (MonadIO m, IncState s d) => LabelReader (SingletonLabel s) (IncServerT s d m) where + getl It = curIncState <$> getIncServerState + +runIncServerT :: (MonadIO m, IncState s d) => s -> IncServerT s d m a -> m a +runIncServerT s cont = do + ref <- newRef $ IncServerState [] mempty s + runReaderT (runIncServerT' cont) ref + +-- === Incremental function server === + +-- If you just need something that computes a function incrementally and doesn't +-- need to maintain any other state then this will do. + +data IncFunctionEvaluatorMsg da b db = + Subscribe_IFEM (SubscribeMsg b db) + | Update_IFEM da + deriving (Show) + +launchIncFunctionEvaluator + :: (IncState b db, Show da, MonadIO m) + => StateServer a da + -> (a -> (b,s)) + -> (b -> s -> da -> (db, s)) + -> m (StateServer b db) +launchIncFunctionEvaluator server fInit fUpdate = + sliceMailbox Subscribe_IFEM <$> launchActor do + x0 <- subscribe Update_IFEM server + let (y0, s0) = fInit x0 + flip evalStateT s0 $ runIncServerT y0 $ messageLoop \case + Subscribe_IFEM msg -> handleSubscribeMsg msg + Update_IFEM dx -> do + y <- getl It + s <- lift get + let (dy, s') = fUpdate y s dx + lift $ put s' + update dy + flushDiffs + +-- === Refs === +-- Just a wrapper around IORef lifted to `MonadIO` + +type Ref = IORef + +newRef :: MonadIO m => a -> m (Ref a) +newRef = liftIO . newIORef + +readRef :: MonadIO m => Ref a -> m a +readRef = liftIO . readIORef + +writeRef :: MonadIO m => Ref a -> a -> m () +writeRef ref val = liftIO $ writeIORef ref val + +-- === Clock === + +-- Provides a periodic clock signal. The time interval is in microseconds. +launchClock :: MonadIO m => Int -> Mailbox () -> m () +launchClock intervalMicroseconds mailbox = + liftIO $ void $ forkIO $ forever do + threadDelay intervalMicroseconds + send mailbox () + +-- === File watcher === + +type SourceFileContents = Text +type FileWatcher = StateServer (Overwritable SourceFileContents) (Overwrite SourceFileContents) + +readFileContents :: MonadIO m => FilePath -> m Text +readFileContents path = liftIO $ T.decodeUtf8 <$> BS.readFile path + +data FileWatcherMsg = + ClockSignal_FW () + | Subscribe_FW (SubscribeMsg (Overwritable Text) (Overwrite Text)) + deriving (Show) + +launchFileWatcher :: MonadIO m => FilePath -> m FileWatcher +launchFileWatcher path = sliceMailbox Subscribe_FW <$> launchActor (fileWatcherImpl path) + +fileWatcherImpl :: FilePath -> ActorM FileWatcherMsg () +fileWatcherImpl path = do + initContents <- readFileContents path + t0 <- liftIO $ getModificationTime path + launchClock 100000 =<< selfMailbox ClockSignal_FW + modTimeRef <- newRef t0 + runIncServerT (Overwritable initContents) $ messageLoop \case + Subscribe_FW msg -> handleSubscribeMsg msg + ClockSignal_FW () -> do + tOld <- readRef modTimeRef + tNew <- liftIO $ getModificationTime path + when (tNew /= tOld) do + newContents <- readFileContents path + update $ OverwriteWith newContents + flushDiffs + writeRef modTimeRef tNew + +-- === instances === + +instance Show msg => Show (SyncMsg msg response) where + show (SyncMsg msg _) = show msg + +instance Show (Mailbox a) where + show _ = "mailbox" +deriving instance Actor msg m => Actor msg (FreshNameT m) diff --git a/src/lib/Algebra.hs b/src/lib/Algebra.hs index 65491714e..1175d1523 100644 --- a/src/lib/Algebra.hs +++ b/src/lib/Algebra.hs @@ -18,14 +18,16 @@ import Data.Text.Prettyprint.Doc import Data.List (intersperse) import Data.Tuple (swap) -import Builder hiding (sub, add, mul) +import Builder import Core +import CheapReduction import Err import IRVariants import MTL1 import Name import Subst import QueryType +import PPrint import Types.Core import Types.Imp import Types.Primitives @@ -48,17 +50,17 @@ newtype Polynomial (n::S) = -- This is the main entrypoint. Doing polynomial math sometimes lets -- us compute sums in closed form. This tries to compute -- `\sum_{i=0}^(lim-1) body`. `i`, `lim`, and `body` should all have type `Nat`. -sumUsingPolys :: (Builder SimpIR m, Fallible1 m, Emits n) - => Atom SimpIR n -> Abs (Binder SimpIR) (Block SimpIR) n -> m n (Atom SimpIR n) +sumUsingPolys :: Emits n + => SAtom n -> Abs (Binder SimpIR) (Expr SimpIR) n -> BuilderM SimpIR n (SAtom n) sumUsingPolys lim (Abs i body) = do sumAbs <- refreshAbs (Abs i body) \(i':>_) body' -> do - blockAsPoly body' >>= \case + exprAsPoly body' >>= \case Just poly' -> return $ Abs i' poly' - Nothing -> throw NotImplementedErr $ + Nothing -> throwInternal $ "Algebraic simplification failed to model index computations:\n" ++ "Trying to sum from 0 to " ++ pprint lim ++ " - 1, \\" ++ pprint i' ++ "." ++ pprint body' - limName <- emit (Atom lim) + limName <- emitToVar (Atom lim) emitPolynomial $ sum (LeftE (atomVarName limName)) sumAbs mul :: Polynomial n-> Polynomial n -> Polynomial n @@ -134,56 +136,53 @@ instance FromName PolySubstVal where fromName = PolyRename type BlockTraverserM i o a = SubstReaderT PolySubstVal (MaybeT1 (BuilderM SimpIR)) i o a -blockAsPoly - :: (EnvExtender m, EnvReader m) - => Block SimpIR n -> m n (Maybe (Polynomial n)) -blockAsPoly (Abs decls result) = - liftBuilder $ runMaybeT1 $ runSubstReaderT idSubst $ blockAsPolyRec decls result - -blockAsPolyRec :: Nest (Decl SimpIR) i i' -> Atom SimpIR i' -> BlockTraverserM i o (Polynomial o) -blockAsPolyRec decls result = case decls of - Empty -> atomAsPoly result - Nest (Let b (DeclBinding _ expr)) restDecls -> do - p <- optional (exprAsPoly expr) - extendSubst (b@>PolySubstVal p) $ blockAsPolyRec restDecls result - - where - atomAsPoly :: Atom SimpIR i -> BlockTraverserM i o (Polynomial o) - atomAsPoly = \case - Var v -> atomVarAsPoly v - RepValAtom (RepVal _ (Leaf (IVar v' _))) -> impNameAsPoly v' - IdxRepVal i -> return $ poly [((fromIntegral i) % 1, mono [])] +exprAsPoly :: (EnvExtender m, EnvReader m) => SExpr n -> m n (Maybe (Polynomial n)) +exprAsPoly expr = liftBuilder $ runMaybeT1 $ runSubstReaderT idSubst $ exprAsPolyRec expr + +atomAsPoly :: SAtom i -> BlockTraverserM i o (Polynomial o) +atomAsPoly = \case + Stuck _ (Var v) -> atomVarAsPoly v + Stuck _ (RepValAtom (RepVal _ (Leaf (IVar v' _)))) -> impNameAsPoly v' + IdxRepVal i -> return $ poly [((fromIntegral i) % 1, mono [])] + _ -> empty + +impNameAsPoly :: ImpName i -> BlockTraverserM i o (Polynomial o) +impNameAsPoly v = getSubst <&> (!v) >>= \case + PolyRename v' -> return $ poly [(1, mono [(RightE v', 1)])] + +atomVarAsPoly :: AtomVar SimpIR i -> BlockTraverserM i o (Polynomial o) +atomVarAsPoly v = getSubst <&> (! atomVarName v) >>= \case + PolySubstVal Nothing -> empty + PolySubstVal (Just cp) -> return cp + PolyRename v' -> do + v'' <- toAtomVar v' + case getType v'' of + IdxRepTy -> return $ poly [(1, mono [(LeftE v', 1)])] _ -> empty - impNameAsPoly :: ImpName i -> BlockTraverserM i o (Polynomial o) - impNameAsPoly v = getSubst <&> (!v) >>= \case - PolyRename v' -> return $ poly [(1, mono [(RightE v', 1)])] - - atomVarAsPoly :: AtomVar SimpIR i -> BlockTraverserM i o (Polynomial o) - atomVarAsPoly v = getSubst <&> (! atomVarName v) >>= \case - PolySubstVal Nothing -> empty - PolySubstVal (Just cp) -> return cp - PolyRename v' -> do - v'' <- toAtomVar v' - case getType v'' of - IdxRepTy -> return $ poly [(1, mono [(LeftE v', 1)])] - _ -> empty - - exprAsPoly :: Expr SimpIR i -> BlockTraverserM i o (Polynomial o) - exprAsPoly e = case e of - Atom a -> atomAsPoly a - PrimOp (BinOp op x y) -> case op of - IAdd -> add <$> atomAsPoly x <*> atomAsPoly y - IMul -> mul <$> atomAsPoly x <*> atomAsPoly y - -- XXX: we rely on the wrapping behavior of subtraction on unsigned ints - -- so that the distributive law holds, `a * (b - c) == (a * b) - (a * c)` - ISub -> sub <$> atomAsPoly x <*> atomAsPoly y - -- This is to handle `idiv` generated by `emitPolynomial` - IDiv -> case y of - IdxRepVal n -> mulConst (1 / fromIntegral n) <$> atomAsPoly x - _ -> empty - _ -> empty +exprAsPolyRec :: Expr SimpIR i -> BlockTraverserM i o (Polynomial o) +exprAsPolyRec e = case e of + Block _ block -> blockAsPoly block + Atom a -> atomAsPoly a + PrimOp (BinOp op x y) -> case op of + IAdd -> add <$> atomAsPoly x <*> atomAsPoly y + IMul -> mul <$> atomAsPoly x <*> atomAsPoly y + -- XXX: we rely on the wrapping behavior of subtraction on unsigned ints + -- so that the distributive law holds, `a * (b - c) == (a * b) - (a * c)` + ISub -> sub <$> atomAsPoly x <*> atomAsPoly y + -- This is to handle `idiv` generated by `emitPolynomial` + IDiv -> case y of + IdxRepVal n -> mulConst (1 / fromIntegral n) <$> atomAsPoly x _ -> empty + _ -> empty + _ -> empty + +blockAsPoly :: SBlock i -> BlockTraverserM i o (Polynomial o) +blockAsPoly (Abs decls result) = case decls of + Empty -> exprAsPolyRec result + Nest (Let b (DeclBinding _ expr)) restDecls -> do + p <- optional (exprAsPolyRec expr) + extendSubst (b@>PolySubstVal p) $ blockAsPoly $ Abs restDecls result -- === polynomials to Core expressions === @@ -192,7 +191,7 @@ blockAsPolyRec decls result = case decls of -- coefficients. This is why we have to find the least common multiples and do the -- accumulation over numbers multiplied by that LCM. We essentially do fixed point -- fractional math here. -emitPolynomial :: (Emits n, Builder SimpIR m) => Polynomial n -> m n (Atom SimpIR n) +emitPolynomial :: Emits n => Polynomial n -> BuilderM SimpIR n (SAtom n) emitPolynomial (Polynomial p) = do let constLCM = asAtom $ foldl lcm 1 $ fmap (denominator . snd) $ toList p monoAtoms <- flip traverse (toList p) $ \(m, c) -> do @@ -206,20 +205,23 @@ emitPolynomial (Polynomial p) = do -- because it might be causing overflows due to all arithmetic being shifted. asAtom = IdxRepVal . fromInteger -emitMonomial :: (Emits n, Builder SimpIR m) => Monomial n -> m n (Atom SimpIR n) +emitMonomial :: Emits n => Monomial n -> BuilderM SimpIR n (SAtom n) emitMonomial (Monomial m) = do varAtoms <- forM (toList m) \(v, e) -> case v of LeftE v' -> do - v'' <- Var <$> toAtomVar v' + v'' <- toAtom <$> toAtomVar v' ipow v'' e RightE v' -> do - let atom = RepValAtom $ RepVal IdxRepTy (Leaf (IVar v' IIdxRepTy)) + atom <- mkStuck $ RepValAtom $ RepVal IdxRepTy (Leaf (IVar v' IIdxRepTy)) ipow atom e foldM imul (IdxRepVal 1) varAtoms -ipow :: (Emits n, Builder SimpIR m) => Atom SimpIR n -> Int -> m n (Atom SimpIR n) +ipow :: Emits n => SAtom n -> Int -> BuilderM SimpIR n (SAtom n) ipow x i = foldM imul (IdxRepVal 1) (replicate i x) +idiv :: Emits n => SAtom n -> SAtom n -> BuilderM SimpIR n (SAtom n) +idiv = undefined + -- === instances === instance GenericE Monomial where diff --git a/src/lib/Builder.hs b/src/lib/Builder.hs index c539d01b4..f3f790f00 100644 --- a/src/lib/Builder.hs +++ b/src/lib/Builder.hs @@ -17,7 +17,6 @@ import Control.Monad.State.Strict (MonadState (..), StateT (..), runStateT) import qualified Data.Map.Strict as M import Data.Foldable (fold) import Data.Graph (graphFromEdges, topSort) -import Data.Text.Prettyprint.Doc (Pretty (..)) import Foreign.Ptr import qualified Unsafe.Coerce as TrulyUnsafe @@ -29,20 +28,23 @@ import IRVariants import MTL1 import Subst import Name +import PeepholeOptimize +import PPrint import QueryType import Types.Core import Types.Imp import Types.Primitives import Types.Source -import Util (enumerate, transitiveClosureM, bindM2, toSnocList, (...)) +import Types.Top +import Util (enumerate, transitiveClosureM, bindM2, toSnocList) -- === Ordinary (local) builder class === -class (EnvReader m, EnvExtender m, Fallible1 m, IRRep r) +class (EnvReader m, Fallible1 m, IRRep r) => Builder (r::IR) (m::MonadKind1) | m -> r where rawEmitDecl :: Emits n => NameHint -> LetAnn -> Expr r n -> m n (AtomVar r n) -class Builder r m => ScopableBuilder (r::IR) (m::MonadKind1) | m -> r where +class (EnvExtender m, Builder r m) => ScopableBuilder (r::IR) (m::MonadKind1) | m -> r where buildScopedAndThen :: SinkableE e => (forall l. (Emits l, DExt n l) => m l (e l)) @@ -62,57 +64,35 @@ type Builder2 (r::IR) (m :: MonadKind2) = forall i. Builder r (m type ScopableBuilder2 (r::IR) (m :: MonadKind2) = forall i. ScopableBuilder r (m i) emitDecl :: (Builder r m, Emits n) => NameHint -> LetAnn -> Expr r n -> m n (AtomVar r n) -emitDecl _ _ (Atom (Var n)) = return n +emitDecl _ _ (Atom (Stuck _ (Var n))) = return n emitDecl hint ann expr = rawEmitDecl hint ann expr {-# INLINE emitDecl #-} -emit :: (Builder r m, Emits n) => Expr r n -> m n (AtomVar r n) -emit expr = emitDecl noHint PlainLet expr +emit :: (Builder r m, ToExpr e r, Emits n) => e n -> m n (Atom r n) +emit e = case toExpr e of + Atom x -> return x + Block _ block -> emitDecls block >>= emit + expr -> do + v <- emitDecl noHint PlainLet $ peepholeExpr expr + return $ toAtom v {-# INLINE emit #-} -emitHinted :: (Builder r m, Emits n) => NameHint -> Expr r n -> m n (AtomVar r n) -emitHinted hint expr = emitDecl hint PlainLet expr -{-# INLINE emitHinted #-} - -emitOp :: (Builder r m, IsPrimOp e, Emits n) => e r n -> m n (Atom r n) -emitOp op = Var <$> emit (PrimOp $ toPrimOp op) -{-# INLINE emitOp #-} - -emitExpr :: (Builder r m, Emits n) => Expr r n -> m n (Atom r n) -emitExpr expr = Var <$> emit expr -{-# INLINE emitExpr #-} - -emitHof :: (Builder r m, Emits n) => Hof r n -> m n (Atom r n) -emitHof hof = mkTypedHof hof >>= emitOp - -mkTypedHof :: (EnvReader m, IRRep r) => Hof r n -> m n (TypedHof r n) -mkTypedHof hof = do - effTy <- effTyOfHof hof - return $ TypedHof effTy hof - -emitUnOp :: (Builder r m, Emits n) => UnOp -> Atom r n -> m n (Atom r n) -emitUnOp op x = emitOp $ UnOp op x -{-# INLINE emitUnOp #-} - -emitBlock :: (Builder r m, Emits n) => Block r n -> m n (Atom r n) -emitBlock = emitDecls +emitToVar :: (Builder r m, ToExpr e r, Emits n) => e n -> m n (AtomVar r n) +emitToVar expr = emit expr >>= \case + Stuck _ (Var v) -> return v + atom -> emitDecl noHint PlainLet (toExpr atom) +{-# INLINE emitToVar #-} emitDecls :: (Builder r m, Emits n, RenameE e, SinkableE e) => WithDecls r e n -> m n (e n) -emitDecls (Abs decls result) = runSubstReaderT idSubst $ emitDecls' decls result - -emitDecls' :: (Builder r m, Emits o, RenameE e, SinkableE e) - => Nest (Decl r) i i' -> e i' -> SubstReaderT Name m i o (e o) -emitDecls' Empty e = renameM e -emitDecls' (Nest (Let b (DeclBinding ann expr)) rest) e = do - expr' <- renameM expr - AtomVar v _ <- emitDecl (getNameHint b) ann expr' - extendSubst (b @> v) $ emitDecls' rest e - -emitExprToAtom :: (Builder r m, Emits n) => Expr r n -> m n (Atom r n) -emitExprToAtom (Atom atom) = return atom -emitExprToAtom expr = Var <$> emit expr -{-# INLINE emitExprToAtom #-} +emitDecls (Abs decls result) = runSubstReaderT idSubst $ go decls result where + go :: (Builder r m, Emits o, RenameE e, SinkableE e) + => Nest (Decl r) i i' -> e i' -> SubstReaderT Name m i o (e o) + go Empty e = renameM e + go (Nest (Let b (DeclBinding ann expr)) rest) e = do + expr' <- renameM expr + AtomVar v _ <- emitDecl (getNameHint b) ann expr' + extendSubst (b @> v) $ go rest e buildScopedAssumeNoDecls :: (SinkableE e, ScopableBuilder r m) => (forall l. (Emits l, DExt n l) => m l (e l)) @@ -153,7 +133,7 @@ liftTopBuilderAndEmit cont = do newtype DoubleBuilderT (r::IR) (topEmissions::B) (m::MonadKind) (n::S) (a:: *) = DoubleBuilderT { runDoubleBuilderT' :: DoubleInplaceT Env topEmissions (BuilderEmissions r) m n a } deriving ( Functor, Applicative, Monad, MonadFail, Fallible - , CtxReader, MonadIO, Catchable, MonadReader r') + , MonadIO, Catchable, MonadReader r') deriving instance (ExtOutMap Env frag, HoistableB frag, OutFrag frag, Fallible m, IRRep r) => ScopeReader (DoubleBuilderT r frag m) @@ -252,9 +232,9 @@ instance ( IRRep r, RenameB frag, HoistableB frag, OutFrag frag {-# INLINE refreshAbs #-} instance (SinkableV v, HoistingTopBuilder f m) => HoistingTopBuilder f (SubstReaderT v m i) where - emitHoistedEnv ab = SubstReaderT $ lift $ emitHoistedEnv ab + emitHoistedEnv ab = liftSubstReaderT $ emitHoistedEnv ab {-# INLINE emitHoistedEnv #-} - canHoistToTop e = SubstReaderT $ lift $ canHoistToTop e + canHoistToTop e = liftSubstReaderT $ canHoistToTop e {-# INLINE canHoistToTop #-} -- === Top-level builder class === @@ -290,20 +270,13 @@ emitTopLet hint letAnn expr = do v <- emitBinding hint $ AtomNameBinding $ LetBound (DeclBinding letAnn expr) return $ AtomVar v ty -emitTopFunBinding :: (Mut n, TopBuilder m) => NameHint -> TopFunDef n -> STopLam n -> m n (TopFunName n) +emitTopFunBinding :: (Mut n, TopBuilder m) => NameHint -> TopFunDef n -> TopLam SimpIR n -> m n (TopFunName n) emitTopFunBinding hint def f = do emitBinding hint $ TopFunBinding $ DexTopFun def f Waiting emitSourceMap :: TopBuilder m => SourceMap n -> m n () emitSourceMap sm = emitLocalModuleEnv $ mempty {envSourceMap = sm} -emitSynthCandidates :: TopBuilder m => SynthCandidates n -> m n () -emitSynthCandidates sc = emitLocalModuleEnv $ mempty {envSynthCandidates = sc} - -addInstanceSynthCandidate :: TopBuilder m => ClassName n -> InstanceName n -> m n () -addInstanceSynthCandidate className instanceName = - emitSynthCandidates $ SynthCandidates [] (M.singleton className [instanceName]) - updateTransposeRelation :: (Mut n, TopBuilder m) => TopFunName n -> TopFunName n -> m n () updateTransposeRelation f1 f2 = updateTopEnv $ ExtendCache $ mempty { transpositionCache = eMapSingleton f1 f2 <> eMapSingleton f2 f1} @@ -365,7 +338,7 @@ getCache = withEnv $ envCache . topEnv newtype TopBuilderT (m::MonadKind) (n::S) (a:: *) = TopBuilderT { runTopBuilderT' :: InplaceT Env TopEnvFrag m n a } deriving ( Functor, Applicative, Monad, MonadFail, Fallible - , CtxReader, ScopeReader, MonadTrans1, MonadReader r + , ScopeReader, MonadTrans1, MonadReader r , MonadWriter w, MonadState s, MonadIO, Catchable) type TopBuilderM = TopBuilderT HardFailM @@ -401,13 +374,13 @@ instance Fallible m => TopBuilder (TopBuilderT m) where {-# INLINE localTopBuilder #-} instance (SinkableV v, TopBuilder m) => TopBuilder (SubstReaderT v m i) where - emitBinding hint binding = SubstReaderT $ lift $ emitBinding hint binding + emitBinding hint binding = liftSubstReaderT $ emitBinding hint binding {-# INLINE emitBinding #-} - emitEnv ab = SubstReaderT $ lift $ emitEnv ab + emitEnv ab = liftSubstReaderT $ emitEnv ab {-# INLINE emitEnv #-} - emitNamelessEnv bs = SubstReaderT $ lift $ emitNamelessEnv bs + emitNamelessEnv bs = liftSubstReaderT $ emitNamelessEnv bs {-# INLINE emitNamelessEnv #-} - localTopBuilder cont = SubstReaderT $ ReaderT \env -> do + localTopBuilder cont = SubstReaderT \env -> do localTopBuilder do Distinct <- getDistinct runReaderT (runSubstReaderT' cont) (sink env) @@ -440,7 +413,7 @@ type BuilderEmissions r = RNest (Decl r) newtype BuilderT (r::IR) (m::MonadKind) (n::S) (a:: *) = BuilderT { runBuilderT' :: InplaceT Env (BuilderEmissions r) m n a } deriving ( Functor, Applicative, Monad, MonadTrans1, MonadFail, Fallible - , Catchable, CtxReader, ScopeReader, Alternative, Searcher + , Catchable, ScopeReader, Alternative , MonadWriter w, MonadReader r') type BuilderM (r::IR) = BuilderT r HardFailM @@ -502,6 +475,8 @@ instance (IRRep r, Fallible m) => Builder r (BuilderT r m) where ty <- return $ getType expr v <- BuilderT $ freshExtendSubInplaceT hint \b -> (BuilderDeclEmission $ Let b $ DeclBinding ann expr, binderName b) + -- -- Debugging snippet + -- traceM $ pprint v ++ " = " ++ pprint expr return $ AtomVar v ty {-# INLINE rawEmitDecl #-} @@ -514,14 +489,14 @@ instance (IRRep r, Fallible m) => EnvExtender (BuilderT r m) where {-# INLINE refreshAbs #-} instance (SinkableV v, ScopableBuilder r m) => ScopableBuilder r (SubstReaderT v m i) where - buildScopedAndThen cont1 cont2 = SubstReaderT $ ReaderT \env -> + buildScopedAndThen cont1 cont2 = SubstReaderT \env -> buildScopedAndThen (runReaderT (runSubstReaderT' cont1) (sink env)) (\d e -> runReaderT (runSubstReaderT' $ cont2 d e) (sink env)) {-# INLINE buildScopedAndThen #-} instance (SinkableV v, Builder r m) => Builder r (SubstReaderT v m i) where - rawEmitDecl hint ann expr = SubstReaderT $ lift $ emitDecl hint ann expr + rawEmitDecl hint ann expr = liftSubstReaderT $ emitDecl hint ann expr {-# INLINE rawEmitDecl #-} instance (SinkableE e, ScopableBuilder r m) => ScopableBuilder r (OutReaderT e m) where @@ -555,6 +530,10 @@ instance (SinkableE e, Builder r m) => Builder r (ReaderT1 e m) where ReaderT1 $ lift $ emitDecl hint ann expr {-# INLINE rawEmitDecl #-} +instance (DiffStateE s d, Builder r m) => Builder r (DiffStateT1 s d m) where + rawEmitDecl hint ann expr = lift11 $ rawEmitDecl hint ann expr + {-# INLINE rawEmitDecl #-} + instance (SinkableE e, HoistableState e, Builder r m) => Builder r (StateT1 e m) where rawEmitDecl hint ann expr = lift11 $ emitDecl hint ann expr {-# INLINE rawEmitDecl #-} @@ -618,10 +597,11 @@ newtype WrapWithEmits n r = -- === lambda-like things === buildBlock - :: ScopableBuilder r m - => (forall l. (Emits l, DExt n l) => m l (Atom r l)) - -> m n (Block r n) -buildBlock = buildScoped + :: (ScopableBuilder r m, HasNamesE e, ToExpr e r) + => (forall l. (Emits l, DExt n l) => m l (e l)) + -> m n (Expr r n) +buildBlock cont = mkBlock =<< buildScoped cont +{-# INLINE buildBlock #-} buildCoreLam :: ScopableBuilder CoreIR m @@ -684,7 +664,7 @@ buildLamExpr (Abs bs UnitE) cont = case bs of Empty -> LamExpr Empty <$> buildBlock (cont []) Nest b rest -> do Abs b' (LamExpr bs' body') <- buildAbs (getNameHint b) (binderType b) \v -> do - rest' <- applySubst (b@>SubstVal (Var v)) $ EmptyAbs rest + rest' <- applySubst (b@>SubstVal (toAtom v)) $ EmptyAbs rest buildLamExpr rest' \vs -> cont $ sink v : vs return $ LamExpr (Nest b' bs') body' @@ -721,9 +701,9 @@ buildCaseAlts scrut indexedAltBody = do injectAltResult :: EnvReader m => [SType n] -> Int -> Alt SimpIR n -> m n (Alt SimpIR n) injectAltResult sumTys con (Abs b body) = liftBuilder do buildAlt (binderType b) \v -> do - originalResult <- emitBlock =<< applySubst (b@>SubstVal (Var v)) body - (dataResult, nonDataResult) <- fromPair originalResult - return $ PairVal dataResult $ Con $ SumCon (sinkList sumTys) con nonDataResult + originalResult <- emit =<< applySubst (b@>SubstVal (toAtom v)) body + (dataResult, nonDataResult) <- fromPairReduced originalResult + return $ toAtom $ ProdCon [dataResult, Con $ SumCon (sinkList sumTys) con nonDataResult] -- TODO: consider a version with nonempty list of alternatives where we figure -- out the result type from one of the alts rather than providing it explicitly @@ -731,27 +711,26 @@ buildCase' :: (Emits n, ScopableBuilder r m) => Atom r n -> Type r n -> (forall l. (Emits l, DExt n l) => Int -> Atom r l -> m l (Atom r l)) -> m n (Expr r n) -buildCase' scrut resultTy indexedAltBody = do - case trySelectBranch scrut of - Just (i, arg) -> do - Distinct <- getDistinct - Atom <$> indexedAltBody i (sink arg) - Nothing -> do - scrutTy <- return $ getType scrut - altBinderTys <- caseAltsBinderTys scrutTy - (alts, effs) <- unzip <$> forM (enumerate altBinderTys) \(i, bTy) -> do - (Abs b' (body `PairE` eff')) <- buildAbs noHint bTy \x -> do - blk <- buildBlock $ indexedAltBody i $ Var $ sink x - EffTy eff _ <- blockEffTy blk - return $ blk `PairE` eff - return (Abs b' body, ignoreHoistFailure $ hoist b' eff') - return $ Case scrut alts $ EffTy (mconcat effs) resultTy +buildCase' scrut resultTy indexedAltBody = case scrut of + Con con -> do + SumCon _ i arg <- return con + Distinct <- getDistinct + Atom <$> indexedAltBody i (sink arg) + Stuck _ _ -> do + scrutTy <- return $ getType scrut + altBinderTys <- caseAltsBinderTys scrutTy + (alts, effs) <- unzip <$> forM (enumerate altBinderTys) \(i, bTy) -> do + (Abs b' (body `PairE` eff')) <- buildAbs noHint bTy \x -> do + blk <- buildBlock $ indexedAltBody i $ toAtom $ sink x + return $ blk `PairE` getEffects blk + return (Abs b' body, ignoreHoistFailure $ hoist b' eff') + return $ Case scrut alts $ EffTy (mconcat effs) resultTy buildCase :: (Emits n, ScopableBuilder r m) => Atom r n -> Type r n -> (forall l. (Emits l, DExt n l) => Int -> Atom r l -> m l (Atom r l)) -> m n (Atom r n) -buildCase s r b = emitExprToAtom =<< buildCase' s r b +buildCase s r b = emit =<< buildCase' s r b buildEffLam :: ScopableBuilder r m @@ -759,46 +738,48 @@ buildEffLam -> (forall l. (Emits l, DExt n l) => AtomVar r l -> AtomVar r l -> m l (Atom r l)) -> m n (LamExpr r n) buildEffLam hint ty body = do - withFreshBinder noHint (TC HeapType) \h -> do - let ty' = RefTy (Var $ binderVar h) (sink ty) + withFreshBinder noHint (TyCon HeapType) \h -> do + let ty' = RefTy (toAtom $ binderVar h) (sink ty) withFreshBinder hint ty' \b -> do let ref = binderVar b hVar <- sinkM $ binderVar h body' <- buildBlock $ body (sink hVar) $ sink ref return $ LamExpr (BinaryNest h b) body' -buildForAnn - :: (Emits n, ScopableBuilder r m) +emitHof :: (Builder r m, Emits n) => Hof r n -> m n (Atom r n) +emitHof hof = mkTypedHof hof >>= emit + +mkTypedHof :: (EnvReader m, IRRep r) => Hof r n -> m n (TypedHof r n) +mkTypedHof hof = do + effTy <- effTyOfHof hof + return $ TypedHof effTy hof + +mkFor + :: (ScopableBuilder r m) => NameHint -> ForAnn -> IxType r n -> (forall l. (Emits l, DExt n l) => AtomVar r l -> m l (Atom r l)) - -> m n (Atom r n) -buildForAnn hint ann (IxType iTy ixDict) body = do + -> m n (Expr r n) +mkFor hint ann (IxType iTy ixDict) body = do lam <- withFreshBinder hint iTy \b -> do let v = binderVar b body' <- buildBlock $ body $ sink v return $ LamExpr (UnaryNest b) body' - emitHof $ For ann (IxType iTy ixDict) lam + liftM toExpr $ mkTypedHof $ For ann (IxType iTy ixDict) lam buildFor :: (Emits n, ScopableBuilder r m) => NameHint -> Direction -> IxType r n -> (forall l. (Emits l, DExt n l) => AtomVar r l -> m l (Atom r l)) -> m n (Atom r n) -buildFor hint dir ty body = buildForAnn hint dir ty body +buildFor hint ann ty body = mkFor hint ann ty body >>= emit -buildMap :: (Emits n, ScopableBuilder r m) - => Atom r n - -> (forall l. (Emits l, DExt n l) => Atom r l -> m l (Atom r l)) - -> m n (Atom r n) +buildMap :: (Emits n, ScopableBuilder SimpIR m) + => SAtom n + -> (forall l. (Emits l, DExt n l) => SAtom l -> m l (SAtom l)) + -> m n (SAtom n) buildMap xs f = do - TabPi t <- return $ getType xs + TabPi t <- return $ getTyCon xs buildFor noHint Fwd (tabIxType t) \i -> - tabApp (sink xs) (Var i) >>= f - -unzipTab :: (Emits n, Builder r m) => Atom r n -> m n (Atom r n, Atom r n) -unzipTab tab = do - fsts <- liftEmitBuilder $ buildMap tab getFst - snds <- liftEmitBuilder $ buildMap tab getSnd - return (fsts, snds) + tabApp (sink xs) (toAtom i) >>= f emitRunWriter :: (Emits n, ScopableBuilder r m) @@ -834,7 +815,7 @@ emitSeq :: (Emits n, ScopableBuilder SimpIR m) -> m n (Atom SimpIR n) emitSeq d t x f = do op <- mkSeq d t x f - emitExpr $ PrimOp $ DAMOp op + emit $ PrimOp $ DAMOp op mkSeq :: EnvReader m => Direction -> IxType SimpIR n -> Atom SimpIR n -> LamExpr SimpIR n @@ -851,19 +832,29 @@ buildRememberDest hint dest cont = do ty <- return $ getType dest doit <- buildUnaryLamExpr hint ty cont effs <- functionEffs doit - emitExpr $ PrimOp $ DAMOp $ RememberDest effs dest doit + emit $ PrimOp $ DAMOp $ RememberDest effs dest doit -- === vector space (ish) type class === +emitLin :: (Builder r m, ToExpr e r, Emits n) => e n -> m n (Atom r n) +emitLin e = case toExpr e of + Atom x -> return x + expr -> liftM toAtom $ emitDecl noHint LinearLet $ peepholeExpr expr +{-# INLINE emitLin #-} + +emitHofLin :: (Builder r m, Emits n) => Hof r n -> m n (Atom r n) +emitHofLin hof = mkTypedHof hof >>= emitLin +{-# INLINE emitHofLin #-} + zeroAt :: (Emits n, SBuilder m) => SType n -> m n (SAtom n) zeroAt ty = liftEmitBuilder $ go ty where go :: Emits n => SType n -> BuilderM SimpIR n (SAtom n) - go = \case - BaseTy bt -> return $ Con $ Lit $ zeroLit bt - ProdTy tys -> ProdVal <$> mapM go tys - TabPi tabPi -> buildFor (getNameHint tabPi) Fwd (tabIxType tabPi) \i -> - go =<< instantiate (sink tabPi) [Var i] - _ -> unreachable + go (TyCon con) = case con of + BaseType bt -> return $ Con $ Lit $ zeroLit bt + ProdType tys -> toAtom . ProdCon <$> mapM go tys + TabPi tabPi -> buildFor (getNameHint tabPi) Fwd (tabIxType tabPi) \i -> + go =<< instantiate (sink tabPi) [toAtom i] + _ -> unreachable zeroLit bt = case bt of Scalar Float64Type -> Float64Lit 0.0 Scalar Float32Type -> Float32Lit 0.0 @@ -884,15 +875,15 @@ maybeTangentType ty = liftEnvReaderT $ maybeTangentType' ty maybeTangentType' :: IRRep r => Type r n -> EnvReaderT Maybe n (Type r n) maybeTangentType' ty = case ty of - TabTy d b bodyTy -> do - refreshAbs (Abs b bodyTy) \b' bodyTy' -> do - bodyTanTy <- rec bodyTy' - return $ TabTy d b' bodyTanTy - TC con -> case con of - BaseType (Scalar Float64Type) -> return $ TC con - BaseType (Scalar Float32Type) -> return $ TC con + TyCon con -> case con of + TabPi (TabPiType d b bodyTy) -> do + refreshAbs (Abs b bodyTy) \b' bodyTy' -> do + bodyTanTy <- rec bodyTy' + return $ TabTy d b' bodyTanTy + BaseType (Scalar Float64Type) -> return $ toType con + BaseType (Scalar Float32Type) -> return $ toType con BaseType _ -> return $ UnitTy - ProdType tys -> ProdTy <$> traverse rec tys + ProdType tys -> toType . ProdType <$> traverse rec tys _ -> empty _ -> empty where rec = maybeTangentType' @@ -900,184 +891,138 @@ maybeTangentType' ty = case ty of tangentBaseMonoidFor :: (Emits n, SBuilder m) => SType n -> m n (BaseMonoid SimpIR n) tangentBaseMonoidFor ty = do zero <- zeroAt ty - adder <- liftBuilder $ buildBinaryLamExpr (noHint, ty) (noHint, ty) \x y -> addTangent (Var x) (Var y) + adder <- liftBuilder $ buildBinaryLamExpr (noHint, ty) (noHint, ty) \x y -> + addTangent (toAtom x) (toAtom y) return $ BaseMonoid zero adder addTangent :: (Emits n, SBuilder m) => SAtom n -> SAtom n -> m n (SAtom n) addTangent x y = do - case getType x of + case getTyCon x of + BaseType (Scalar _) -> emit $ BinOp FAdd x y + ProdType _ -> do + xs <- getUnpacked x + ys <- getUnpacked y + toAtom . ProdCon <$> zipWithM addTangent xs ys TabPi t -> liftEmitBuilder $ buildFor (getNameHint t) Fwd (tabIxType t) \i -> do - bindM2 addTangent (tabApp (sink x) (Var i)) (tabApp (sink y) (Var i)) - TC con -> case con of - BaseType (Scalar _) -> emitOp $ BinOp FAdd x y - ProdType _ -> do - xs <- getUnpacked x - ys <- getUnpacked y - ProdVal <$> zipWithM addTangent xs ys - ty -> notTangent ty + bindM2 addTangent (tabApp (sink x) (toAtom i)) (tabApp (sink y) (toAtom i)) ty -> notTangent ty where notTangent ty = error $ "Not a tangent type: " ++ pprint ty symbolicTangentTy :: (EnvReader m, Fallible1 m) => CType n -> m n (CType n) symbolicTangentTy elTy = lookupSourceMap "SymbolicTangent" >>= \case Just (UTyConVar symTanName) -> do - return $ TypeCon "SymbolicTangent" symTanName $ - TyConParams [Explicit] [Type elTy] - Nothing -> throw UnboundVarErr $ + return $ toType $ UserADTType "SymbolicTangent" symTanName $ + TyConParams [Explicit] [toAtom elTy] + Nothing -> throwInternal $ "Can't define a custom linearization with symbolic zeros: " ++ "the SymbolicTangent type is not in scope." - Just _ -> throw TypeErr "SymbolicTangent should name a `data` type" + Just _ -> throwInternal $ "SymbolicTangent should name a `data` type" symbolicTangentZero :: EnvReader m => SType n -> m n (SAtom n) -symbolicTangentZero argTy = return $ SumVal [UnitTy, argTy] 0 UnitVal +symbolicTangentZero argTy = return $ toAtom $ SumCon [UnitTy, argTy] 0 UnitVal symbolicTangentNonZero :: EnvReader m => SAtom n -> m n (SAtom n) symbolicTangentNonZero val = do ty <- return $ getType val - return $ SumVal [UnitTy, ty] 1 val + return $ toAtom $ SumCon [UnitTy, ty] 1 val -- === builder versions of common local ops === -fLitLike :: (Builder r m, Emits n) => Double -> Atom r n -> m n (Atom r n) -fLitLike x t = do - ty <- return $ getType t - case ty of - BaseTy (Scalar Float64Type) -> return $ Con $ Lit $ Float64Lit x - BaseTy (Scalar Float32Type) -> return $ Con $ Lit $ Float32Lit $ realToFrac x - _ -> error "Expected a floating point scalar" +fadd :: (Builder r m, Emits n) => Atom r n -> Atom r n -> m n (Atom r n) +fadd x y = emit $ BinOp FAdd x y -neg :: (Builder r m, Emits n) => Atom r n -> m n (Atom r n) -neg x = emitOp $ UnOp FNeg x +fsub :: (Builder r m, Emits n) => Atom r n -> Atom r n -> m n (Atom r n) +fsub x y = emit $ BinOp FSub x y -add :: (Builder r m, Emits n) => Atom r n -> Atom r n -> m n (Atom r n) -add x y = emitOp $ BinOp FAdd x y +fmul :: (Builder r m, Emits n) => Atom r n -> Atom r n -> m n (Atom r n) +fmul x y = emit $ BinOp FMul x y --- TODO: Implement constant folding for fixed-width integer types as well! -iadd :: (Builder r m, Emits n) => Atom r n -> Atom r n -> m n (Atom r n) -iadd (Con (Lit l)) y | getIntLit l == 0 = return y -iadd x (Con (Lit l)) | getIntLit l == 0 = return x -iadd x@(Con (Lit _)) y@(Con (Lit _)) = return $ applyIntBinOp (+) x y -iadd x y = emitOp $ BinOp IAdd x y +fdiv :: (Builder r m, Emits n) => Atom r n -> Atom r n -> m n (Atom r n) +fdiv x y = emit $ BinOp FDiv x y -mul :: (Builder r m, Emits n) => Atom r n -> Atom r n -> m n (Atom r n) -mul x y = emitOp $ BinOp FMul x y +iadd :: (Builder r m, Emits n) => Atom r n -> Atom r n -> m n (Atom r n) +iadd x y = emit $ BinOp IAdd x y imul :: (Builder r m, Emits n) => Atom r n -> Atom r n -> m n (Atom r n) -imul (Con (Lit l)) y | getIntLit l == 1 = return y -imul x (Con (Lit l)) | getIntLit l == 1 = return x -imul x@(Con (Lit _)) y@(Con (Lit _)) = return $ applyIntBinOp (*) x y -imul x y = emitOp $ BinOp IMul x y - -sub :: (Builder r m, Emits n) => Atom r n -> Atom r n -> m n (Atom r n) -sub x y = emitOp $ BinOp FSub x y - -isub :: (Builder r m, Emits n) => Atom r n -> Atom r n -> m n (Atom r n) -isub x (Con (Lit l)) | getIntLit l == 0 = return x -isub x@(Con (Lit _)) y@(Con (Lit _)) = return $ applyIntBinOp (-) x y -isub x y = emitOp $ BinOp ISub x y - -select :: (Builder r m, Emits n) => Atom r n -> Atom r n -> Atom r n -> m n (Atom r n) -select (Con (Lit (Word8Lit p))) x y = return $ if p /= 0 then x else y -select p x y = emitOp $ MiscOp $ Select p x y - -div' :: (Builder r m, Emits n) => Atom r n -> Atom r n -> m n (Atom r n) -div' x y = emitOp $ BinOp FDiv x y - -idiv :: (Builder r m, Emits n) => Atom r n -> Atom r n -> m n (Atom r n) -idiv x (Con (Lit l)) | getIntLit l == 1 = return x -idiv x@(Con (Lit _)) y@(Con (Lit _)) = return $ applyIntBinOp div x y -idiv x y = emitOp $ BinOp IDiv x y - -irem :: (Builder r m, Emits n) => Atom r n -> Atom r n -> m n (Atom r n) -irem x y = emitOp $ BinOp IRem x y +imul x y = emit $ BinOp IMul x y -fpow :: (Builder r m, Emits n) => Atom r n -> Atom r n -> m n (Atom r n) -fpow x y = emitOp $ BinOp FPow x y +fLitLike :: Double -> SAtom n -> SAtom n +fLitLike x t = case getTyCon t of + BaseType (Scalar Float64Type) -> toAtom $ Lit $ Float64Lit x + BaseType (Scalar Float32Type) -> toAtom $ Lit $ Float32Lit $ realToFrac x + _ -> error "Expected a floating point scalar" -flog :: (Builder r m, Emits n) => Atom r n -> m n (Atom r n) -flog x = emitOp $ UnOp Log x - -ilt :: (Builder r m, Emits n) => Atom r n -> Atom r n -> m n (Atom r n) -ilt x@(Con (Lit _)) y@(Con (Lit _)) = return $ applyIntCmpOp (<) x y -ilt x y = emitOp $ BinOp (ICmp Less) x y - -ieq :: (Builder r m, Emits n) => Atom r n -> Atom r n -> m n (Atom r n) -ieq x@(Con (Lit _)) y@(Con (Lit _)) = return $ applyIntCmpOp (==) x y -ieq x y = emitOp $ BinOp (ICmp Equal) x y - -fromPair :: (Fallible1 m, EnvReader m, IRRep r) => Atom r n -> m n (Atom r n, Atom r n) +fromPair :: (Builder r m, Emits n) => Atom r n -> m n (Atom r n, Atom r n) fromPair pair = do getUnpacked pair >>= \case [x, y] -> return (x, y) _ -> error "expected a pair" -getFst :: Builder r m => Atom r n -> m n (Atom r n) -getFst p = fst <$> fromPair p - -getSnd :: Builder r m => Atom r n -> m n (Atom r n) -getSnd p = snd <$> fromPair p - -- the rightmost index is applied first -getNaryProjRef :: (Builder r m, Emits n) => [Projection] -> Atom r n -> m n (Atom r n) -getNaryProjRef [] ref = return ref -getNaryProjRef (i:is) ref = getProjRef i =<< getNaryProjRef is ref +applyProjectionsRef :: (Builder r m, Emits n) => [Projection] -> Atom r n -> m n (Atom r n) +applyProjectionsRef [] ref = return ref +applyProjectionsRef (i:is) ref = getProjRef i =<< applyProjectionsRef is ref getProjRef :: (Builder r m, Emits n) => Projection -> Atom r n -> m n (Atom r n) -getProjRef i r = emitOp =<< mkProjRef r i +getProjRef i r = emit =<< mkProjRef r i -- XXX: getUnpacked must reduce its argument to enforce the invariant that -- ProjectElt atoms are always fully reduced (to avoid type errors between two -- equivalent types spelled differently). -getUnpacked :: (Fallible1 m, EnvReader m, IRRep r) => Atom r n -> m n [Atom r n] -getUnpacked atom = do - atom' <- cheapNormalize atom - ty <- return $ getType atom' - positions <- case ty of - ProdTy tys -> return $ void tys - DepPairTy _ -> return [(), ()] - _ -> error $ "not a product type: " ++ pprint ty - forM (enumerate positions) \(i, _) -> - normalizeProj (ProjectProduct i) atom' +getUnpacked :: (Builder r m, Emits n) => Atom r n -> m n [Atom r n] +getUnpacked atom = forM (productIdxs atom) \i -> proj i atom {-# SCC getUnpacked #-} -getProj :: (Builder r m, Emits n) => Int -> Atom r n -> m n (Atom r n) -getProj i atom = do - atom' <- cheapNormalize atom - normalizeProj (ProjectProduct i) atom' +productIdxs :: IRRep r => Atom r n -> [Int] +productIdxs atom = + let positions = case getType atom of + TyCon (ProdType tys) -> void tys + TyCon (DepPairTy _) -> [(), ()] + ty -> error $ "not a product type: " ++ pprint ty + in fst <$> enumerate positions -emitUnpacked :: (Builder r m, Emits n) => Atom r n -> m n [AtomVar r n] -emitUnpacked tup = do - xs <- getUnpacked tup - forM xs \x -> emit $ Atom x - -unwrapNewtype :: EnvReader m => CAtom n -> m n (CAtom n) -unwrapNewtype (NewtypeCon _ x) = return x +unwrapNewtype :: (Emits n, Builder CoreIR m) => CAtom n -> m n (CAtom n) +unwrapNewtype (Con (NewtypeCon _ x)) = return x unwrapNewtype x = case getType x of - NewtypeTyCon con -> do + TyCon (NewtypeTyCon con) -> do (_, ty) <- unwrapNewtypeType con - return $ ProjectElt ty UnwrapNewtype x + emit $ Unwrap ty x _ -> error "not a newtype" {-# INLINE unwrapNewtype #-} -projectTuple :: (IRRep r, EnvReader m) => Int -> Atom r n -> m n (Atom r n) -projectTuple i x = normalizeProj (ProjectProduct i) x +proj ::(Builder r m, Emits n) => Int -> Atom r n -> m n (Atom r n) +proj i = \case + Con con -> case con of + ProdCon xs -> return $ xs !! i + DepPair l _ _ | i == 0 -> return l + DepPair _ r _ | i == 1 -> return r + _ -> error "not a product" + x -> do + ty <- projType i x + emit $ Project ty i x + +getFst :: (Builder r m, Emits n) => Atom r n -> m n (Atom r n) +getFst = proj 0 -projectStruct :: EnvReader m => Int -> CAtom n -> m n (CAtom n) +getSnd :: (Builder r m, Emits n) => Atom r n -> m n (Atom r n) +getSnd = proj 1 + +projectStruct :: (Builder CoreIR m, Emits n) => Int -> CAtom n -> m n (CAtom n) projectStruct i x = do projs <- getStructProjections i (getType x) - normalizeNaryProj projs x + applyProjections projs x {-# INLINE projectStruct #-} projectStructRef :: (Builder CoreIR m, Emits n) => Int -> CAtom n -> m n (CAtom n) projectStructRef i x = do RefTy _ valTy <- return $ getType x projs <- getStructProjections i valTy - getNaryProjRef projs x + applyProjectionsRef projs x {-# INLINE projectStructRef #-} getStructProjections :: EnvReader m => Int -> CType n -> m n [Projection] -getStructProjections i (NewtypeTyCon (UserADTType _ tyConName _)) = do +getStructProjections i (TyCon (NewtypeTyCon (UserADTType _ tyConName _))) = do TyConDef _ _ _ ~(StructFields fields) <- lookupTyCon tyConName return case fields of [_] | i == 0 -> [UnwrapNewtype] @@ -1085,94 +1030,127 @@ getStructProjections i (NewtypeTyCon (UserADTType _ tyConName _)) = do _ -> [ProjectProduct i, UnwrapNewtype] getStructProjections _ _ = error "not a struct" +-- the rightmost index is applied first +applyProjections :: (Builder CoreIR m, Emits n) => [Projection] -> CAtom n -> m n (CAtom n) +applyProjections [] x = return x +applyProjections (p:ps) x = do + x' <- applyProjections ps x + case p of + ProjectProduct i -> proj i x' + UnwrapNewtype -> unwrapNewtype x' + +-- the rightmost index is applied first +applyProjectionsReduced :: EnvReader m => [Projection] -> CAtom n -> m n (CAtom n) +applyProjectionsReduced [] x = return x +applyProjectionsReduced (p:ps) x = do + x' <- applyProjectionsReduced ps x + case p of + ProjectProduct i -> reduceProj i x' + UnwrapNewtype -> reduceUnwrap x' + +mkBlock :: (EnvReader m, IRRep r) => ToExpr e r => Abs (Decls r) e n -> m n (Expr r n) +mkBlock (Abs decls body) = do + let block = Abs decls (toExpr body) + effTy <- blockEffTy block + return $ Block effTy block + +blockEffTy :: (EnvReader m, IRRep r) => Block r n -> m n (EffTy r n) +blockEffTy block = liftEnvReaderM $ refreshAbs block \decls result -> do + effs <- declsEffects decls mempty + return $ ignoreHoistFailure $ hoist decls $ EffTy effs $ getType result + where + declsEffects :: IRRep r => Nest (Decl r) n l -> EffectRow r l -> EnvReaderM l (EffectRow r l) + declsEffects Empty !acc = return acc + declsEffects n@(Nest (Let _ (DeclBinding _ expr)) rest) !acc = withExtEvidence n do + expr' <- sinkM expr + declsEffects rest $ acc <> getEffects expr' + mkApp :: EnvReader m => CAtom n -> [CAtom n] -> m n (CExpr n) mkApp f xs = do et <- appEffTy (getType f) xs return $ App et f xs -mkTabApp :: (EnvReader m, IRRep r) => Atom r n -> [Atom r n] -> m n (Expr r n) +mkTabApp :: (EnvReader m, IRRep r) => Atom r n -> Atom r n -> m n (Expr r n) mkTabApp xs ixs = do ty <- typeOfTabApp (getType xs) ixs return $ TabApp ty xs ixs +mkProject :: (EnvReader m, IRRep r) => Int -> Atom r n -> m n (Expr r n) +mkProject i x = do + ty <- projType i x + return $ Project ty i x + mkTopApp :: EnvReader m => TopFunName n -> [SAtom n] -> m n (SExpr n) mkTopApp f xs = do resultTy <- typeOfTopApp f xs return $ TopApp resultTy f xs -mkApplyMethod :: EnvReader m => CAtom n -> Int -> [CAtom n] -> m n (CExpr n) +mkApplyMethod :: EnvReader m => CDict n -> Int -> [CAtom n] -> m n (CExpr n) mkApplyMethod d i xs = do resultTy <- typeOfApplyMethod d i xs - return $ ApplyMethod resultTy d i xs + return $ ApplyMethod resultTy (toAtom d) i xs -mkDictAtom :: EnvReader m => DictExpr n -> m n (CAtom n) -mkDictAtom d = do - ty <- typeOfDictExpr d - return $ DictCon ty d +mkInstanceDict :: EnvReader m => InstanceName n -> [CAtom n] -> m n (CDict n) +mkInstanceDict instanceName args = do + instanceDef@(InstanceDef className _ _ _ _) <- lookupInstanceDef instanceName + PairE (ListE params) _ <- instantiate instanceDef args + ty <- toType <$> dictType className params + return $ toDict $ InstanceDict ty instanceName args mkCase :: (EnvReader m, IRRep r) => Atom r n -> Type r n -> [Alt r n] -> m n (Expr r n) mkCase scrut resultTy alts = liftEnvReaderM do eff' <- fold <$> forM alts \alt -> refreshAbs alt \b body -> do - EffTy eff _ <- blockEffTy body - return $ ignoreHoistFailure $ hoist b eff + return $ ignoreHoistFailure $ hoist b $ getEffects body return $ Case scrut alts (EffTy eff' resultTy) -mkCatchException :: EnvReader m => CBlock n -> m n (Hof CoreIR n) +mkCatchException :: EnvReader m => CExpr n -> m n (Hof CoreIR n) mkCatchException body = do - EffTy _ bodyTy <- blockEffTy body - resultTy <- makePreludeMaybeTy bodyTy + resultTy <- makePreludeMaybeTy (getType body) return $ CatchException resultTy body app :: (CBuilder m, Emits n) => CAtom n -> CAtom n -> m n (CAtom n) -app x i = mkApp x [i] >>= emitExpr +app x i = mkApp x [i] >>= emit naryApp :: (CBuilder m, Emits n) => CAtom n -> [CAtom n] -> m n (CAtom n) -naryApp = naryAppHinted noHint +naryApp f xs= mkApp f xs >>= emit {-# INLINE naryApp #-} naryTopApp :: (Builder SimpIR m, Emits n) => TopFunName n -> [SAtom n] -> m n (SAtom n) -naryTopApp f xs = emitExpr =<< mkTopApp f xs +naryTopApp f xs = emit =<< mkTopApp f xs {-# INLINE naryTopApp #-} naryTopAppInlined :: (Builder SimpIR m, Emits n) => TopFunName n -> [SAtom n] -> m n (SAtom n) naryTopAppInlined f xs = do TopFunBinding f' <- lookupEnv f case f' of - DexTopFun _ lam _ -> instantiate lam xs >>= emitBlock + DexTopFun _ lam _ -> instantiate lam xs >>= emit _ -> naryTopApp f xs {-# INLINE naryTopAppInlined #-} -naryAppHinted :: (CBuilder m, Emits n) - => NameHint -> CAtom n -> [CAtom n] -> m n (CAtom n) -naryAppHinted hint f xs = Var <$> (mkApp f xs >>= emitHinted hint) - tabApp :: (Builder r m, Emits n) => Atom r n -> Atom r n -> m n (Atom r n) -tabApp x i = mkTabApp x [i] >>= emitExpr +tabApp x i = mkTabApp x i >>= emit naryTabApp :: (Builder r m, Emits n) => Atom r n -> [Atom r n] -> m n (Atom r n) -naryTabApp = naryTabAppHinted noHint +naryTabApp f [] = return f +naryTabApp f (x:xs) = do + ans <- mkTabApp f x >>= emit + naryTabApp ans xs {-# INLINE naryTabApp #-} -naryTabAppHinted :: (Builder r m, Emits n) - => NameHint -> Atom r n -> [Atom r n] -> m n (Atom r n) -naryTabAppHinted hint f xs = do - expr <- mkTabApp f xs - Var <$> emitHinted hint expr - indexRef :: (Builder r m, Emits n) => Atom r n -> Atom r n -> m n (Atom r n) -indexRef ref i = emitOp =<< mkIndexRef ref i +indexRef ref i = emit =<< mkIndexRef ref i naryIndexRef :: (Builder r m, Emits n) => Atom r n -> [Atom r n] -> m n (Atom r n) naryIndexRef ref is = foldM indexRef ref is ptrOffset :: (Builder r m, Emits n) => Atom r n -> Atom r n -> m n (Atom r n) ptrOffset x (IdxRepVal 0) = return x -ptrOffset x i = emitOp $ MemOp $ PtrOffset x i +ptrOffset x i = emit $ MemOp $ PtrOffset x i {-# INLINE ptrOffset #-} unsafePtrLoad :: (Builder r m, Emits n) => Atom r n -> m n (Atom r n) unsafePtrLoad x = do - body <- liftEmitBuilder $ buildBlock $ emitOp . MemOp . PtrLoad =<< sinkM x + body <- liftEmitBuilder $ buildBlock $ emit . MemOp . PtrLoad =<< sinkM x emitHof $ RunIO body mkIndexRef :: (EnvReader m, Fallible1 m, IRRep r) => Atom r n -> Atom r n -> m n (PrimOp r n) @@ -1188,16 +1166,16 @@ mkProjRef ref i = do -- === index set type class === applyIxMethod :: (SBuilder m, Emits n) => IxDict SimpIR n -> IxMethod -> [SAtom n] -> m n (SAtom n) -applyIxMethod dict method args = case dict of +applyIxMethod (DictCon dict) method args = case dict of -- These cases are used in SimpIR and they work with IdxRepVal - IxDictRawFin n -> case method of + IxRawFin n -> case method of Size -> do [] <- return args; return n Ordinal -> do [i] <- return args; return i UnsafeFromOrdinal -> do [i] <- return args; return i - IxDictSpecialized _ d params -> do + IxSpecialized d params -> do SpecializedDict _ maybeFs <- lookupSpecDict d Just fs <- return maybeFs - instantiate (fs !! fromEnum method) (params ++ args) >>= emitBlock + instantiate (fs !! fromEnum method) (params ++ args) >>= emit unsafeFromOrdinal :: (SBuilder m, Emits n) => IxType SimpIR n -> Atom SimpIR n -> m n (Atom SimpIR n) unsafeFromOrdinal (IxType _ dict) i = applyIxMethod dict UnsafeFromOrdinal [i] @@ -1211,9 +1189,8 @@ indexSetSize (IxType _ dict) = applyIxMethod dict Size [] -- === core versions of index set type class === applyIxMethodCore :: (CBuilder m, Emits n) => IxMethod -> IxType CoreIR n -> [CAtom n] -> m n (CAtom n) -applyIxMethodCore method (IxType _ (IxDictAtom dict)) args = do - emitExpr =<< mkApplyMethod dict (fromEnum method) args -applyIxMethodCore _ _ _ = error "not an ix type" +applyIxMethodCore method (IxType _ dict) args = + emit =<< mkApplyMethod dict (fromEnum method) args -- === pseudo-prelude === @@ -1229,7 +1206,7 @@ emitIf :: (Emits n, ScopableBuilder r m) -> (forall l. (Emits l, DExt n l) => m l (Atom r l)) -> m n (Atom r n) emitIf predicate resultTy trueCase falseCase = do - predicate' <- emitOp $ MiscOp $ ToEnum (SumTy [UnitTy, UnitTy]) predicate + predicate' <- emit $ ToEnum (TyCon (SumType [UnitTy, UnitTy])) predicate buildCase predicate' resultTy \i _ -> case i of 0 -> falseCase @@ -1253,7 +1230,7 @@ fromJustE :: (Emits n, Builder r m) => Atom r n -> m n (Atom r n) fromJustE x = liftEmitBuilder do MaybeTy a <- return $ getType x emitMaybeCase x a - (emitOp $ MiscOp $ ThrowError $ sink a) + (emit $ MiscOp $ ThrowError $ sink a) (return) -- Maybe a -> Bool @@ -1262,31 +1239,31 @@ isJustE x = liftEmitBuilder $ emitMaybeCase x BoolTy (return FalseAtom) (\_ -> return TrueAtom) -- Monoid a -> (n=>a) -> a -reduceE :: (Emits n, Builder r m) => BaseMonoid r n -> Atom r n -> m n (Atom r n) +reduceE :: (Emits n, SBuilder m) => BaseMonoid SimpIR n -> SAtom n -> m n (SAtom n) reduceE monoid xs = liftEmitBuilder do - TabPi tabPi <- return $ getType xs + TabPi tabPi <- return $ getTyCon xs let a = assumeConst tabPi getSnd =<< emitRunWriter noHint a monoid \_ ref -> buildFor noHint Fwd (sink $ tabIxType tabPi) \i -> do - x <- tabApp (sink xs) (Var i) - emitExpr $ PrimOp $ RefOp (sink $ Var ref) $ MExtend (sink monoid) x + x <- tabApp (sink xs) (toAtom i) + emit $ PrimOp $ RefOp (sink $ toAtom ref) $ MExtend (sink monoid) x andMonoid :: (EnvReader m, IRRep r) => m n (BaseMonoid r n) andMonoid = liftM (BaseMonoid TrueAtom) $ liftBuilder $ buildBinaryLamExpr (noHint, BoolTy) (noHint, BoolTy) \x y -> - emitOp $ BinOp BAnd (sink $ Var x) (Var y) + emit $ BinOp BAnd (sink $ toAtom x) (toAtom y) -- (a-> {|eff} b) -> n=>a -> {|eff} (n=>b) -mapE :: (Emits n, ScopableBuilder r m) - => (forall l. (Emits l, DExt n l) => Atom r l -> m l (Atom r l)) - -> Atom r n -> m n (Atom r n) +mapE :: (Emits n, ScopableBuilder SimpIR m) + => (forall l. (Emits l, DExt n l) => SAtom l -> m l (SAtom l)) + -> SAtom n -> m n (SAtom n) mapE cont xs = do - TabPi tabPi <- return $ getType xs + TabPi tabPi <- return $ getTyCon xs buildFor (getNameHint tabPi) Fwd (tabIxType tabPi) \i -> do - tabApp (sink xs) (Var i) >>= cont + tabApp (sink xs) (toAtom i) >>= cont -- (n:Type) ?-> (a:Type) ?-> (xs : n=>Maybe a) : Maybe (n => a) = -catMaybesE :: (Emits n, Builder r m) => Atom r n -> m n (Atom r n) +catMaybesE :: (Emits n, SBuilder m) => SAtom n -> m n (SAtom n) catMaybesE maybes = do TabTy d n (MaybeTy a) <- return $ getType maybes justs <- liftEmitBuilder $ mapE isJustE maybes @@ -1325,7 +1302,7 @@ runMaybeWhile body = do emitWhile do ans <- body emitMaybeCase ans Word8Ty - (emit (PrimOp $ RefOp (sink $ Var ref) $ MPut TrueAtom) >> return FalseAtom) + (emit (RefOp (sink $ toAtom ref) $ MPut TrueAtom) >> return FalseAtom) (return) return UnitVal emitIf hadError (MaybeTy UnitTy) @@ -1390,7 +1367,7 @@ telescopicCapture bs e = do let vsTysSorted = toposortAnnVars $ zip vs vTys let vsSorted = map fst vsTysSorted ty <- liftEnvReaderM $ buildTelescopeTy vsTysSorted - valsSorted <- forM vsSorted \v -> Var <$> toAtomVar v + valsSorted <- forM vsSorted \v -> toAtom <$> toAtomVar v result <- buildTelescopeVal valsSorted ty reconAbs <- return $ ignoreHoistFailure $ hoist bs do case abstractFreeVarsNoAnn vsSorted e of @@ -1430,19 +1407,19 @@ buildTelescopeVal xsTop tyTop = fst <$> go tyTop xsTop where go ty rest = case ty of ProdTelescope tys -> do (xs, rest') <- return $ splitAt (length tys) rest - return (ProdVal xs, rest') + return (toAtom $ ProdCon xs, rest') DepTelescope ty1 (Abs b ty2) -> do (x1, ~(xDep : rest')) <- go ty1 rest ty2' <- applySubst (b@>SubstVal xDep) ty2 (x2, rest'') <- go ty2' rest' let depPairTy = DepPairType ExplicitDepPair b (telescopeTypeType ty2) - return (PairVal x1 (DepPair xDep x2 depPairTy), rest'') + return (toAtom $ ProdCon [x1, toAtom $ DepPair xDep x2 depPairTy], rest'') telescopeTypeType :: TelescopeType (AtomNameC r) (Type r) n -> Type r n -telescopeTypeType (ProdTelescope tys) = ProdTy tys +telescopeTypeType (ProdTelescope tys) = toType $ ProdType tys telescopeTypeType (DepTelescope lhs (Abs b rhs)) = do let lhs' = telescopeTypeType lhs - let rhs' = DepPairTy (DepPairType ExplicitDepPair b (telescopeTypeType rhs)) + let rhs' = toType $ DepPairTy (DepPairType ExplicitDepPair b (telescopeTypeType rhs)) PairTy lhs' rhs' unpackTelescope @@ -1452,14 +1429,20 @@ unpackTelescope (ReconBinders tyTop _) xTop = go tyTop xTop where go :: (Fallible1 m, EnvReader m, IRRep r) => TelescopeType c e l-> Atom r n -> m n [Atom r n] go ty x = case ty of - ProdTelescope _ -> getUnpacked x + ProdTelescope _ -> getUnpackedReduced x DepTelescope ty1 (Abs _ ty2) -> do - (x1, xPair) <- fromPair x - (xDep, x2) <- fromPair xPair + (x1, xPair) <- fromPairReduced x + (xDep, x2) <- fromPairReduced xPair xs1 <- go ty1 x1 xs2 <- go ty2 x2 return $ xs1 ++ (xDep : xs2) +fromPairReduced :: (Fallible1 m, EnvReader m, IRRep r) => Atom r n -> m n (Atom r n, Atom r n) +fromPairReduced pair = (,) <$> reduceProj 0 pair <*> reduceProj 1 pair + +getUnpackedReduced :: (Fallible1 m, EnvReader m, IRRep r) => Atom r n -> m n [Atom r n] +getUnpackedReduced xs = forM (productIdxs xs) \i -> reduceProj i xs + -- sorts name-annotation pairs so that earlier names may be occur free in later -- annotations but not vice versa. type AnnVar c e n = (Name c n, e n) @@ -1509,10 +1492,8 @@ type ExprVisitorNoEmits2 m r = forall i o. ExprVisitorNoEmits (m i o) r i o visitLamNoEmits :: (ExprVisitorNoEmits2 m r, IRRep r, AtomSubstReader v m, EnvExtender2 m) => LamExpr r i -> m i o (LamExpr r o) -visitLamNoEmits (LamExpr bs (Abs decls result)) = - visitBinders bs \bs' -> LamExpr bs' <$> - visitDeclsNoEmits decls \decls' -> Abs decls' <$> do - visitAtom result +visitLamNoEmits (LamExpr bs body) = + visitBinders bs \bs' -> LamExpr bs' <$> visitExprNoEmits body visitDeclsNoEmits :: (ExprVisitorNoEmits2 m r, IRRep r, AtomSubstReader v m, EnvExtender2 m) @@ -1550,12 +1531,7 @@ visitLamEmits :: (ExprVisitorEmits2 m r, IRRep r, SubstReader AtomSubstVal m, ScopableBuilder2 r m) => LamExpr r i -> m i o (LamExpr r o) visitLamEmits (LamExpr bs body) = visitBinders bs \bs' -> LamExpr bs' <$> - buildBlock (visitBlockEmits body) - -visitBlockEmits - :: (ExprVisitorEmits2 m r, SubstReader AtomSubstVal m, EnvExtender2 m, IRRep r, Emits o) - => Block r i -> m i o (Atom r o) -visitBlockEmits (Abs decls result) = visitDeclsEmits decls $ visitAtom result + buildBlock (visitExprEmits body) visitDeclsEmits :: (ExprVisitorEmits2 m r, SubstReader AtomSubstVal m, EnvExtender2 m, IRRep r, Emits o) @@ -1567,30 +1543,3 @@ visitDeclsEmits (Nest (Let b (DeclBinding _ expr)) decls) cont = do x <- visitExprEmits expr extendSubst (b@>SubstVal x) do visitDeclsEmits decls cont - --- === Helpers for function evaluation over fixed-width types === - -applyIntBinOp' :: (forall a. (Eq a, Ord a, Num a, Integral a) - => (a -> Atom r n) -> a -> a -> Atom r n) -> Atom r n -> Atom r n -> Atom r n -applyIntBinOp' f x y = case (x, y) of - (Con (Lit (Int64Lit xv)), Con (Lit (Int64Lit yv))) -> f (Con . Lit . Int64Lit) xv yv - (Con (Lit (Int32Lit xv)), Con (Lit (Int32Lit yv))) -> f (Con . Lit . Int32Lit) xv yv - (Con (Lit (Word8Lit xv)), Con (Lit (Word8Lit yv))) -> f (Con . Lit . Word8Lit) xv yv - (Con (Lit (Word32Lit xv)), Con (Lit (Word32Lit yv))) -> f (Con . Lit . Word32Lit) xv yv - (Con (Lit (Word64Lit xv)), Con (Lit (Word64Lit yv))) -> f (Con . Lit . Word64Lit) xv yv - _ -> error "Expected integer atoms" - -applyIntBinOp :: (forall a. (Num a, Integral a) => a -> a -> a) -> Atom r n -> Atom r n -> Atom r n -applyIntBinOp f x y = applyIntBinOp' (\w -> w ... f) x y - -applyIntCmpOp :: (forall a. (Eq a, Ord a) => a -> a -> Bool) -> Atom r n -> Atom r n -> Atom r n -applyIntCmpOp f x y = applyIntBinOp' (\_ -> (Con . Lit . Word8Lit . fromIntegral . fromEnum) ... f) x y - -applyFloatBinOp :: (forall a. (Num a, Fractional a) => a -> a -> a) -> Atom r n -> Atom r n -> Atom r n -applyFloatBinOp f x y = case (x, y) of - (Con (Lit (Float64Lit xv)), Con (Lit (Float64Lit yv))) -> Con $ Lit $ Float64Lit $ f xv yv - (Con (Lit (Float32Lit xv)), Con (Lit (Float32Lit yv))) -> Con $ Lit $ Float32Lit $ f xv yv - _ -> error "Expected float atoms" - -_applyFloatUnOp :: (forall a. (Num a, Fractional a) => a -> a) -> Atom r n -> Atom r n -_applyFloatUnOp f x = applyFloatBinOp (\_ -> f) (error "shouldn't be needed") x diff --git a/src/lib/Cat.hs b/src/lib/Cat.hs deleted file mode 100644 index 7c1e49132..000000000 --- a/src/lib/Cat.hs +++ /dev/null @@ -1,178 +0,0 @@ --- Copyright 2019 Google LLC --- --- Use of this source code is governed by a BSD-style --- license that can be found in the LICENSE file or at --- https://developers.google.com/open-source/licenses/bsd - -{-# LANGUAGE UndecidableInstances #-} - -module Cat (CatT, MonadCat, runCatT, look, extend, scoped, looks, extendLocal, - extendR, captureW, asFst, asSnd, capture, asCat, evalCatT, evalCat, - Cat, runCat, newCatT, catTraverse, evalScoped, execCat, execCatT, - catFold, catFoldM, catMap, catMapM) where - --- Monad for tracking monoidal state - -import Control.Applicative -import Control.Monad.State.Strict -import Control.Monad.Reader -import Control.Monad.Writer -import Control.Monad.Identity -import Control.Monad.Except hiding (Except) - -import Err - -newtype CatT env m a = CatT (StateT (env, env) m a) - deriving (Functor, Applicative, Monad, MonadTrans, MonadIO, MonadFail, Alternative, - Fallible) - -type Cat env = CatT env Identity - -class (Monoid env, Monad m) => MonadCat env m | m -> env where - look :: m env - extend :: env -> m () - scoped :: m a -> m (a, env) - -instance (Monoid env, Monad m) => MonadCat env (CatT env m) where - look = CatT $ gets fst - extend x = CatT $ do - (fullState, localState) <- get - put (fullState <> x, localState <> x) - scoped (CatT m) = CatT $ do - originalState <- get - put (fst originalState, mempty) - ans <- m - newLocalState <- gets snd - put originalState - return (ans, newLocalState) - -instance MonadCat env m => MonadCat env (StateT s m) where - look = lift look - extend x = lift $ extend x - scoped m = StateT \s -> do - ((ans, s'), env) <- scoped $ runStateT m s - return $ ((ans, env), s') - -instance MonadCat env m => MonadCat env (ReaderT r m) where - look = lift look - extend x = lift $ extend x - scoped m = do r <- ask - lift $ scoped $ runReaderT m r - -instance (Monoid w, MonadCat env m) => MonadCat env (WriterT w m) where - look = lift look - extend x = lift $ extend x - scoped m = do ((x, w), env) <- lift $ scoped $ runWriterT m - tell w - return (x, env) - -instance MonadCat env m => MonadCat env (ExceptT e m) where - look = lift look - extend x = lift $ extend x - scoped m = do (xerr, env) <- lift $ scoped $ runExceptT m - case xerr of - Left err -> throwError err - Right x -> return (x, env) - -instance (Monoid env, MonadReader r m) => MonadReader r (CatT env m) where - ask = lift ask - local f m = do - env <- look - (ans, env') <- lift $ local f $ runCatT m env - extend env' - return ans - -runCatT :: (Monoid env, Monad m) => CatT env m a -> env -> m (a, env) -runCatT (CatT m) initEnv = do - (ans, (_, newEnv)) <- runStateT m (initEnv, mempty) - return (ans, newEnv) - -evalCatT :: (Monoid env, Monad m) => CatT env m a -> m a -evalCatT m = fst <$> runCatT m mempty - -execCatT :: (Monoid env, Monad m) => CatT env m a -> m env -execCatT m = snd <$> runCatT m mempty - -newCatT :: (Monoid env, Monad m) => (env -> m (a, env)) -> CatT env m a -newCatT f = do - env <- look - (ans, env') <- lift $ f env - extend env' - return ans - -runCat :: Monoid env => Cat env a -> env -> (a, env) -runCat m env = runIdentity $ runCatT m env - -evalCat :: Monoid env => Cat env a -> a -evalCat m = runIdentity $ evalCatT m - -execCat :: Monoid env => Cat env a -> env -execCat m = runIdentity $ execCatT m - -looks :: (Monoid env, MonadCat env m) => (env -> a) -> m a -looks getter = liftM getter look - -evalScoped :: Monoid env => Cat env a -> Cat env a -evalScoped m = fst <$> scoped m - -capture :: (Monoid env, MonadCat env m) => m a -> m (a, env) -capture m = do - (x, env) <- scoped m - extend env - return (x, env) - -extendLocal :: (Monoid env, MonadCat env m) => env -> m a -> m a -extendLocal x m = do - ((ans, env), _) <- scoped $ do extend x - scoped m - extend env - return ans - --- Not part of the cat monad, but related utils for monoidal envs - -catTraverse :: (Monoid menv, MonadReader env m, Traversable f) - => (a -> m (b, menv)) -> (menv -> env) -> f a -> menv -> m (f b, menv) -catTraverse f inj xs env = runCatT (traverse (asCat f inj) xs) env - -catFoldM :: (Monoid env, Traversable t, Monad m) - => (env -> a -> m env) -> env -> t a -> m env -catFoldM f env xs = liftM snd $ flip runCatT env $ forM_ xs \x -> do - cur <- look - new <- lift $ f cur x - extend new - -catFold :: (Monoid env, Traversable t) - => (env -> a -> env) -> env -> t a -> env -catFold f env xs = runIdentity $ catFoldM (\e x -> Identity $ f e x) env xs - -catMapM :: (Monoid env, Traversable t, Monad m) - => (env -> a -> m (b, env)) -> env -> t a -> m (t b, env) -catMapM f env xs = flip runCatT env $ forM xs \x -> do - cur <- look - (y, new) <- lift $ f cur x - extend new - return y - -catMap :: (Monoid env, Traversable t) - => (env -> a -> (b, env)) -> env -> t a -> (t b, env) -catMap f env xs = runIdentity $ catMapM (\e x -> Identity $ f e x) env xs - -asCat :: (Monoid menv, MonadReader env m) - => (a -> m (b, menv)) -> (menv -> env) -> a -> CatT menv m b -asCat f inj x = do - env' <- look - (x', env'') <- lift $ local (const $ inj env') (f x) - extend env'' - return x' - -extendR :: (Monoid env, MonadReader env m) => env -> m a -> m a -extendR x m = local (<> x) m - -asFst :: Monoid b => a -> (a, b) -asFst x = (x, mempty) - -asSnd :: Monoid a => b -> (a, b) -asSnd y = (mempty, y) - -captureW :: MonadWriter w m => m a -> m (a, w) -captureW m = censor (const mempty) (listen m) diff --git a/src/lib/CheapReduction.hs b/src/lib/CheapReduction.hs index c4cc41bb1..41bd2f1d4 100644 --- a/src/lib/CheapReduction.hs +++ b/src/lib/CheapReduction.hs @@ -8,41 +8,35 @@ {-# OPTIONS_GHC -Wno-orphans #-} module CheapReduction - ( CheaplyReducibleE (..), cheapReduce, cheapReduceWithDecls, cheapNormalize - , normalizeProj, asNaryProj, normalizeNaryProj - , depPairLeftTy, instantiateTyConDef - , dataDefRep, unwrapNewtypeType, repValAtom - , unwrapLeadingNewtypesType, wrapNewtypesData, liftSimpAtom, liftSimpType - , liftSimpFun, makeStructRepVal, NonAtomRenamer (..), Visitor (..), VisitGeneric (..) - , visitAtomPartial, visitTypePartial, visitAtomDefault, visitTypeDefault, Visitor2 + ( reduceWithDecls, reduceExpr + , instantiateTyConDef, dataDefRep, unwrapNewtypeType, projType + , NonAtomRenamer (..), Visitor (..), VisitGeneric (..) + , visitAtomDefault, visitTypeDefault, Visitor2, mkStuck, mkStuckTy , visitBinders, visitPiDefault, visitAlt, toAtomVar, instantiate, withInstantiated - , bindersToVars, bindersToAtoms, instantiateNames, withInstantiatedNames, assumeConst) + , bindersToVars, bindersToAtoms, instantiateNames, withInstantiatedNames, assumeConst + , repValAtom, reduceUnwrap, reduceProj, reduceSuperclassProj, typeOfApp + , reduceInstantiateGiven, queryStuckType, substMStuck, reduceTabApp, substStuck + , liftSimpAtom, reduceACase) where import Control.Applicative -import Control.Monad.Trans import Control.Monad.Writer.Strict hiding (Alt) -import Control.Monad.State.Strict -import Control.Monad.Reader -import Data.Foldable (toList) -import Data.Functor.Identity import Data.Functor ((<&>)) -import qualified Data.List.NonEmpty as NE -import qualified Data.Map.Strict as M +import Data.Maybe (fromJust) import Subst import Core import Err import IRVariants -import MTL1 import Name -import PPrint () +import PPrint import QueryTypePure import Types.Core +import Types.Top import Types.Imp import Types.Primitives import Util -import {-# SOURCE #-} Inference (trySynthTerm) +import GHC.Stack -- Carry out the reductions we are willing to carry out during type -- inference. The goal is to support type aliases like `Int = Int32` @@ -54,382 +48,228 @@ import {-# SOURCE #-} Inference (trySynthTerm) -- === api === -type NiceE r e = (HoistableE e, SinkableE e, SubstE AtomSubstVal e, RenameE e, IRRep r) - -cheapReduce :: forall r e' e m n - . (EnvReader m, CheaplyReducibleE r e e', NiceE r e, NiceE r e') - => e n -> m n (Maybe (e' n)) -cheapReduce e = liftCheapReducerM idSubst $ cheapReduceE e -{-# INLINE cheapReduce #-} -{-# SCC cheapReduce #-} - -cheapReduceWithDecls - :: forall r e' e m n l - . ( CheaplyReducibleE r e e', NiceE r e', NiceE r e, EnvReader m ) - => Nest (Decl r) n l -> e l -> m n (Maybe (e' n)) -cheapReduceWithDecls decls result = do - Abs decls' result' <- sinkM $ Abs decls result - liftCheapReducerM idSubst $ - cheapReduceWithDeclsB decls' $ - cheapReduceE result' -{-# INLINE cheapReduceWithDecls #-} -{-# SCC cheapReduceWithDecls #-} - -cheapNormalize :: (EnvReader m, CheaplyReducibleE r e e, NiceE r e) => e n -> m n (e n) -cheapNormalize a = cheapReduce a >>= \case - Just ans -> return ans - _ -> error "couldn't normalize expression" -{-# INLINE cheapNormalize #-} +reduceWithDecls + :: (IRRep r, HasNamesE e, SubstE AtomSubstVal e, EnvReader m) + => WithDecls r e n -> m n (Maybe (e n)) +reduceWithDecls (Abs decls e) = + liftReducerM $ reduceWithDeclsM decls $ substM e --- === internal === +reduceExpr :: (IRRep r, EnvReader m) => Expr r n -> m n (Maybe (Atom r n)) +reduceExpr e = liftReducerM $ reduceExprM e +{-# INLINE reduceExpr #-} -newtype CheapReducerM (r::IR) (i :: S) (o :: S) (a :: *) = - CheapReducerM - (SubstReaderT AtomSubstVal - (MaybeT1 - (ScopedT1 (MapE (AtomName r) (MaybeE (Atom r))) - (EnvReaderT Identity))) i o a) - deriving (Functor, Applicative, Monad, Alternative) - -deriving instance IRRep r => ScopeReader (CheapReducerM r i) -deriving instance IRRep r => EnvReader (CheapReducerM r i) -deriving instance IRRep r => EnvExtender (CheapReducerM r i) -deriving instance IRRep r => SubstReader AtomSubstVal (CheapReducerM r) - -class ( Alternative2 m, SubstReader AtomSubstVal m - , EnvReader2 m, EnvExtender2 m) => CheapReducer m r | m -> r where - updateCache :: AtomName r o -> Maybe (Atom r o) -> m i o () - lookupCache :: AtomName r o -> m i o (Maybe (Maybe (Atom r o))) - -instance IRRep r => CheapReducer (CheapReducerM r) r where - updateCache v u = CheapReducerM $ SubstReaderT $ lift $ lift11 $ - modify (MapE . M.insert v (toMaybeE u) . fromMapE) - lookupCache v = CheapReducerM $ SubstReaderT $ lift $ lift11 $ - fmap fromMaybeE <$> gets (M.lookup v . fromMapE) - -liftCheapReducerM - :: (EnvReader m, IRRep r) - => Subst AtomSubstVal i o -> CheapReducerM r i o a - -> m o (Maybe a) -liftCheapReducerM subst (CheapReducerM m) = do - liftM runIdentity $ liftEnvReaderT $ runScopedT1 - (runMaybeT1 $ runSubstReaderT subst m) mempty -{-# INLINE liftCheapReducerM #-} - -cheapReduceWithDeclsB - :: NiceE r e - => Nest (Decl r) i i' - -> (forall o'. Ext o o' => CheapReducerM r i' o' (e o')) - -> CheapReducerM r i o (e o) -cheapReduceWithDeclsB decls cont = do - Abs irreducibleDecls result <- cheapReduceWithDeclsRec decls cont - case hoist irreducibleDecls result of - HoistSuccess result' -> return result' - HoistFailure _ -> empty - -cheapReduceWithDeclsRec - :: NiceE r e - => Nest (Decl r) i i' - -> (forall o'. Ext o o' => CheapReducerM r i' o' (e o')) - -> CheapReducerM r i o (Abs (Nest (Decl r)) e o) -cheapReduceWithDeclsRec decls cont = case decls of - Empty -> Abs Empty <$> cont - Nest (Let b binding@(DeclBinding _ expr)) rest -> do - optional (cheapReduceE expr) >>= \case - Nothing -> do - binding' <- substM binding - withFreshBinder (getNameHint b) binding' \(b':>_) -> do - updateCache (binderName b') Nothing - extendSubst (b@>Rename (binderName b')) do - Abs decls' result <- cheapReduceWithDeclsRec rest cont - return $ Abs (Nest (Let b' binding') decls') result - Just x -> - extendSubst (b@>SubstVal x) $ - cheapReduceWithDeclsRec rest cont - -cheapReduceName :: forall c r i o . (IRRep r, Color c) => Name c o -> CheapReducerM r i o (AtomSubstVal c o) -cheapReduceName v = - case eqColorRep @c @(AtomNameC r) of - Just ColorsEqual -> - lookupEnv v >>= \case - AtomNameBinding binding -> cheapReduceAtomBinding v binding - Nothing -> stuck - where stuck = return $ Rename v - -cheapReduceAtomBinding - :: forall r i o. IRRep r - => AtomName r o -> AtomBinding r o -> CheapReducerM r i o (AtomSubstVal (AtomNameC r) o) -cheapReduceAtomBinding v = \case - LetBound (DeclBinding _ e) -> do - cachedVal <- lookupCache v >>= \case - Nothing -> do - result <- optional (dropSubst $ cheapReduceE e) - updateCache v result - return result - Just result -> return result - case cachedVal of - Nothing -> stuck - Just ans -> return $ SubstVal ans - _ -> stuck - where stuck = return $ Rename v - -class CheaplyReducibleE (r::IR) (e::E) (e'::E) | e -> e', e -> r where - cheapReduceE :: e i -> CheapReducerM r i o (e' o) - -instance IRRep r => CheaplyReducibleE r (Atom r) (Atom r) where - cheapReduceE :: forall i o. Atom r i -> CheapReducerM r i o (Atom r o) - cheapReduceE a = confuseGHC >>= \_ -> case a of - -- Don't try to eagerly reduce lambda bodies. We might get stuck long before - -- we have a chance to apply tham. Also, recursive traversal of those bodies - -- means that we will follow the full call chain, so it's really expensive! - -- TODO: we don't collect the dict holes here, so there's a danger of - -- dropping them if they turn out to be phantom. - Lam _ -> substM a - DictHole ctx ty' access -> do - ty <- cheapReduceE ty' - runFallibleT1 (trySynthTerm ty access) >>= \case - Success d -> return d - Failure _ -> return $ DictHole ctx ty access - -- We traverse the Atom constructors that might contain lambda expressions - -- explicitly, to make sure that we can skip normalizing free vars inside those. - Con con -> Con <$> traverseOp con cheapReduceE cheapReduceE (error "unexpected lambda") - DictCon t d -> do - t' <- cheapReduceE t - cheapReduceDictExpr t' d - SimpInCore (LiftSimp t x) -> do - t' <- cheapReduceE t - x' <- substM x - liftSimpAtom t' x' - -- These two are a special-case hack. TODO(dougalm): write a traversal over - -- the NewtypeTyCon (or types generally) - NewtypeCon NatCon n -> NewtypeCon NatCon <$> cheapReduceE n - -- Do recursive reduction via substitution - -- TODO: we don't collect the dict holes here, so there's a danger of - -- dropping them if they turn out to be phantom. - _ -> do - a' <- substM a - dropSubst $ traverseNames cheapReduceName a' - -instance IRRep r => CheaplyReducibleE r (Type r) (Type r) where - cheapReduceE :: forall i o. Type r i -> CheapReducerM r i o (Type r o) - cheapReduceE a = case a of - -- Don't try to eagerly reduce lambda bodies. We might get stuck long before - -- we have a chance to apply tham. Also, recursive traversal of those bodies - -- means that we will follow the full call chain, so it's really expensive! - -- TODO: we don't collect the dict holes here, so there's a danger of - -- dropping them if they turn out to be phantom. - TabPi (TabPiType d (b:>t) resultTy) -> do - t' <- cheapReduceE t - d' <- cheapReduceE d - withFreshBinder (getNameHint b) t' \b' -> do - resultTy' <- extendSubst (b@>Rename (binderName b')) $ cheapReduceE resultTy - return $ TabPi $ TabPiType d' b' resultTy' - -- We traverse the Atom constructors that might contain lambda expressions - -- explicitly, to make sure that we can skip normalizing free vars inside those. - NewtypeTyCon (Fin n) -> NewtypeTyCon . Fin <$> cheapReduceE n - -- Do recursive reduction via substitution - -- TODO: we don't collect the dict holes here, so there's a danger of - -- dropping them if they turn out to be phantom. - _ -> do - a' <- substM a - dropSubst $ traverseNames cheapReduceName a' - -cheapReduceDictExpr :: CType o -> DictExpr i -> CheapReducerM CoreIR i o (CAtom o) -cheapReduceDictExpr resultTy d = case d of - SuperclassProj child superclassIx -> do - cheapReduceE child >>= \case - DictCon _ (InstanceDict instanceName args) -> dropSubst do - args' <- mapM cheapReduceE args - InstanceDef _ _ bs _ body <- lookupInstanceDef instanceName - let InstanceBody superclasses _ = body - instantiate (Abs bs (superclasses !! superclassIx)) args' - child' -> return $ DictCon resultTy $ SuperclassProj child' superclassIx - InstantiatedGiven f xs -> - reduceApp <|> justSubst - where reduceApp = do - f' <- cheapReduceE f - xs' <- mapM cheapReduceE (toList xs) - cheapReduceApp f' xs' - InstanceDict _ _ -> justSubst - IxFin _ -> justSubst - DataData ty -> DictCon resultTy . DataData <$> cheapReduceE ty - where justSubst = DictCon resultTy <$> substM d - -instance CheaplyReducibleE CoreIR TyConParams TyConParams where - cheapReduceE (TyConParams infs ps) = - TyConParams infs <$> mapM cheapReduceE ps - -instance (CheaplyReducibleE r e e', NiceE r e') => CheaplyReducibleE r (Abs (Nest (Decl r)) e) e' where - cheapReduceE (Abs decls result) = cheapReduceWithDeclsB decls $ cheapReduceE result - -instance IRRep r => CheaplyReducibleE r (Expr r) (Atom r) where - cheapReduceE expr = confuseGHC >>= \_ -> case expr of - Atom atom -> cheapReduceE atom - App _ f' xs' -> do - xs <- mapM cheapReduceE xs' - f <- cheapReduceE f' - cheapReduceApp f xs - -- TODO: Make sure that this wraps correctly - -- TODO: Other casts? - PrimOp (MiscOp (CastOp ty' val')) -> do - ty <- cheapReduceE ty' - case ty of - BaseTy (Scalar Word32Type) -> do - val <- cheapReduceE val' - case val of - Con (Lit (Word64Lit v)) -> return $ Con $ Lit $ Word32Lit $ fromIntegral v - _ -> empty - _ -> empty - ApplyMethod _ dict i explicitArgs -> do - explicitArgs' <- mapM cheapReduceE explicitArgs - cheapReduceE dict >>= \case - DictCon _ (InstanceDict instanceName args) -> dropSubst do - args' <- mapM cheapReduceE args - def <- lookupInstanceDef instanceName - withInstantiated def args' \(PairE _ (InstanceBody _ methods)) -> do - method' <- cheapReduceE $ methods !! i - cheapReduceApp method' explicitArgs' - _ -> empty - _ -> empty +-- TODO: just let the caller use `liftReducerM` themselves directly? -cheapReduceApp :: CAtom o -> [CAtom o] -> CheapReducerM CoreIR i o (CAtom o) -cheapReduceApp f xs = case f of - Lam lam -> dropSubst $ withInstantiated lam xs \body -> cheapReduceE body - _ -> empty - -instance IRRep r => CheaplyReducibleE r (IxType r) (IxType r) where - cheapReduceE (IxType t d) = IxType <$> cheapReduceE t <*> cheapReduceE d - -instance IRRep r => CheaplyReducibleE r (IxDict r) (IxDict r) where - cheapReduceE = \case - IxDictAtom x -> IxDictAtom <$> cheapReduceE x - IxDictRawFin n -> IxDictRawFin <$> cheapReduceE n - IxDictSpecialized t d xs -> - IxDictSpecialized <$> cheapReduceE t <*> substM d <*> mapM cheapReduceE xs - -instance (CheaplyReducibleE r e1 e1', CheaplyReducibleE r e2 e2') - => CheaplyReducibleE r (PairE e1 e2) (PairE e1' e2') where - cheapReduceE (PairE e1 e2) = PairE <$> cheapReduceE e1 <*> cheapReduceE e2 - -instance (CheaplyReducibleE r e1 e1', CheaplyReducibleE r e2 e2') - => CheaplyReducibleE r (EitherE e1 e2) (EitherE e1' e2') where - cheapReduceE (LeftE e) = LeftE <$> cheapReduceE e - cheapReduceE (RightE e) = RightE <$> cheapReduceE e - --- XXX: TODO: figure out exactly what our normalization invariants are. We --- shouldn't have to choose `normalizeProj` or `asNaryProj` on a --- case-by-case basis. This is here for now because it makes it easier to switch --- to the new version of `ProjectElt`. -asNaryProj :: IRRep r => Projection -> Atom r n -> (NE.NonEmpty Projection, AtomVar r n) -asNaryProj p (Var v) = (p NE.:| [], v) -asNaryProj p1 (ProjectElt _ p2 x) = do - let (p2' NE.:| ps, v) = asNaryProj p2 x - (p1 NE.:| (p2':ps), v) -asNaryProj p x = error $ "Can't normalize projection: " ++ pprint p ++ " " ++ pprint x - --- assumes the atom is already normalized -normalizeNaryProj :: IRRep r => EnvReader m => [Projection] -> Atom r n -> m n (Atom r n) -normalizeNaryProj [] x = return x -normalizeNaryProj (i:is) x = normalizeProj i =<< normalizeNaryProj is x - --- assumes the atom itself is already normalized -normalizeProj :: IRRep r => EnvReader m => Projection -> Atom r n -> m n (Atom r n) -normalizeProj UnwrapNewtype atom = case atom of - NewtypeCon _ x -> return x - SimpInCore (LiftSimp (NewtypeTyCon t) x) -> do - t' <- snd <$> unwrapNewtypeType t - return $ SimpInCore $ LiftSimp t' x - x -> case getType x of - NewtypeTyCon t -> do - t' <- snd <$> unwrapNewtypeType t - return $ ProjectElt t' UnwrapNewtype x - _ -> error "expected a newtype" -normalizeProj (ProjectProduct i) atom = case atom of - Con (ProdCon xs) -> return $ xs !! i - DepPair l _ _ | i == 0 -> return l - DepPair _ r _ | i == 1 -> return r - SimpInCore (LiftSimp _ x) -> do - x' <- normalizeProj (ProjectProduct i) x - resultTy <- getResultTy - return $ SimpInCore $ LiftSimp resultTy x' - RepValAtom (RepVal _ tree) -> case tree of - Branch trees -> do - resultTy <- getResultTy - repValAtom $ RepVal resultTy (trees!!i) - Leaf _ -> error "unexpected leaf" - _ -> do - resultTy <- getResultTy - return $ ProjectElt resultTy (ProjectProduct i) atom - where - getResultTy = projType i (getType atom) atom -{-# INLINE normalizeProj #-} +reduceProj :: (IRRep r, EnvReader m) => Int -> Atom r n -> m n (Atom r n) +reduceProj i x = liftM fromJust $ liftReducerM $ reduceProjM i x +{-# INLINE reduceProj #-} + +reduceACase :: EnvReader m => SAtom n -> [Abs SBinder CAtom n] -> CType n -> m n (CAtom n) +reduceACase scrut alts resultTy = liftM fromJust $ liftReducerM $ reduceACaseM scrut alts resultTy +{-# INLINE reduceACase #-} + +reduceUnwrap :: EnvReader m => CAtom n -> m n (CAtom n) +reduceUnwrap x = liftM fromJust $ liftReducerM $ reduceUnwrapM x +{-# INLINE reduceUnwrap #-} + +reduceSuperclassProj :: EnvReader m => Int -> CDict n -> m n (CAtom n) +reduceSuperclassProj i x = liftM fromJust $ liftReducerM $ reduceSuperclassProjM i x +{-# INLINE reduceSuperclassProj #-} + +reduceInstantiateGiven :: EnvReader m => CAtom n -> [CAtom n] -> m n (CAtom n) +reduceInstantiateGiven f xs = liftM fromJust $ liftReducerM $ reduceInstantiateGivenM f xs +{-# INLINE reduceInstantiateGiven #-} --- === lifting imp to simp and simp to core === +reduceTabApp :: (IRRep r, EnvReader m) => Atom r n -> Atom r n -> m n (Atom r n) +reduceTabApp f x = liftM fromJust $ liftReducerM $ reduceTabAppM f x +{-# INLINE reduceTabApp #-} -repValAtom :: EnvReader m => SRepVal n -> m n (SAtom n) +-- === internal === + +type ReducerM = SubstReaderT AtomSubstVal (EnvReaderT Except) + +liftReducerM :: EnvReader m => ReducerM n n a -> m n (Maybe a) +liftReducerM cont = do + liftM ignoreExcept $ liftEnvReaderT $ runSubstReaderT idSubst do + (Just <$> cont) <|> return Nothing + +reduceWithDeclsM :: IRRep r => Nest (Decl r) i i' -> ReducerM i' o a -> ReducerM i o a +reduceWithDeclsM Empty cont = cont +reduceWithDeclsM (Nest (Let b (DeclBinding _ expr)) rest) cont = do + x <- reduceExprM expr + extendSubst (b@>SubstVal x) $ reduceWithDeclsM rest cont + +reduceExprM :: IRRep r => Expr r i -> ReducerM i o (Atom r o) +reduceExprM = \case + Atom x -> substM x + Block _ (Abs decls result) -> reduceWithDeclsM decls $ reduceExprM result + App _ f xs -> mapM substM xs >>= reduceApp f + Unwrap _ x -> substM x >>= reduceUnwrapM + Project _ i x -> substM x >>= reduceProjM i + ApplyMethod _ dict i explicitArgs -> do + explicitArgs' <- mapM substM explicitArgs + dict' <- substM dict + case dict' of + Con (DictConAtom (InstanceDict _ instanceName args)) -> dropSubst do + def <- lookupInstanceDef instanceName + withInstantiated def args \(PairE _ (InstanceBody _ methods)) -> do + reduceApp (methods !! i) explicitArgs' + _ -> empty + PrimOp (MiscOp (CastOp ty' val')) -> do + ty <- substM ty' + val <- substM val' + case (ty, val) of + (TyCon (BaseType (Scalar Word32Type)), Con (Lit (Word64Lit v))) -> + return $ Con $ Lit $ Word32Lit $ fromIntegral v + _ -> empty + TabApp _ tab x -> do + x' <- substM x + tab' <- substM tab + reduceTabAppM tab' x' + TopApp _ _ _ -> empty + Case _ _ _ -> empty + TabCon _ _ _ -> empty + PrimOp _ -> empty + +reduceApp :: CAtom i -> [CAtom o] -> ReducerM i o (CAtom o) +reduceApp f xs = do + f' <- substM f -- TODO: avoid double-subst + case f' of + Con (Lam lam) -> dropSubst $ withInstantiated lam xs \body -> reduceExprM body + _ -> empty + +reduceACaseM :: SAtom n -> [Abs SBinder CAtom n] -> CType n -> ReducerM i n (CAtom n) +reduceACaseM scrut alts resultTy = case scrut of + Con (SumCon _ i arg) -> do + Abs b body <- return $ alts !! i + applySubst (b@>SubstVal arg) body + Con _ -> error "not a sum type" + Stuck _ scrut' -> mkStuck $ ACase scrut' alts resultTy + +reduceProjM :: IRRep r => Int -> Atom r o -> ReducerM i o (Atom r o) +reduceProjM i x = case x of + Con con -> case con of + ProdCon xs -> return $ xs !! i + DepPair l _ _ | i == 0 -> return l + DepPair _ r _ | i == 1 -> return r + _ -> error "not a product" + Stuck _ e -> mkStuck $ StuckProject i e + +reduceSuperclassProjM :: Int -> CDict o -> ReducerM i o (CAtom o) +reduceSuperclassProjM superclassIx dict = case dict of + DictCon (InstanceDict _ instanceName args) -> dropSubst do + args' <- mapM substM args + InstanceDef _ _ bs _ body <- lookupInstanceDef instanceName + let InstanceBody superclasses _ = body + instantiate (Abs bs (superclasses !! superclassIx)) args' + StuckDict _ child -> mkStuck $ SuperclassProj superclassIx child + _ -> error "invalid superclass projection" + +reduceInstantiateGivenM :: CAtom o -> [CAtom o] -> ReducerM i o (CAtom o) +reduceInstantiateGivenM f xs = case f of + Con (Lam lam) -> dropSubst $ withInstantiated lam xs \body -> reduceExprM body + Stuck _ f' -> mkStuck $ InstantiatedGiven f' xs + _ -> error "bad instantiation" + +mkStuck:: (IRRep r, EnvReader m) => Stuck r n -> m n (Atom r n) +mkStuck x = do + ty <- queryStuckType x + return $ Stuck ty x + +mkStuckTy :: EnvReader m => CStuck n -> m n (CType n) +mkStuckTy x = do + ty <- queryStuckType x + return $ StuckTy ty x + +queryStuckType :: (IRRep r, EnvReader m) => Stuck r n -> m n (Type r n) +queryStuckType = \case + Var v -> return $ getType v + StuckProject i s -> projType i =<< mkStuck s + StuckTabApp f x -> do + fTy <- queryStuckType f + typeOfTabApp fTy x + PtrVar t _ -> return $ PtrTy t + RepValAtom repVal -> return $ getType repVal + StuckUnwrap s -> queryStuckType s >>= \case + TyCon (NewtypeTyCon con) -> snd <$> unwrapNewtypeType con + _ -> error "not a newtype" + InstantiatedGiven f xs -> do + fTy <- queryStuckType f + typeOfApp fTy xs + SuperclassProj i s -> superclassProjType i =<< queryStuckType s + LiftSimp t _ -> return t + LiftSimpFun t _ -> return $ toType t + -- TabLam and ACase are just defunctionalization tools. The result type + -- in both cases should *not* be `Data`. + TabLam (PairE t _) -> return $ toType t + ACase _ _ resultTy -> return resultTy + +projType :: (IRRep r, EnvReader m) => Int -> Atom r n -> m n (Type r n) +projType i x = case getType x of + TyCon con -> case con of + ProdType xs -> return $ xs !! i + DepPairTy t | i == 0 -> return $ depPairLeftTy t + DepPairTy t | i == 1 -> do + liftReducerM (reduceProjM 0 x) >>= \case + Just xFst -> instantiate t [xFst] + Nothing -> err + _ -> err + _ -> err + where err = error $ "Can't project type: " ++ pprint (getType x) + +superclassProjType :: EnvReader m => Int -> CType n -> m n (CType n) +superclassProjType i (TyCon (DictTy dictTy)) = case dictTy of + DictType _ className params -> do + ClassDef _ _ _ _ _ bs superclasses _ <- lookupClassDef className + instantiate (Abs bs $ getSuperclassType REmpty superclasses i) params + IxDictType t | i == 0 -> return $ toType $ DataDictType t + _ -> error "bad superclass projection" +superclassProjType _ _ = error "bad superclass projection" + +typeOfTabApp :: (IRRep r, EnvReader m) => Type r n -> Atom r n -> m n (Type r n) +typeOfTabApp (TyCon (TabPi piTy)) x = withSubstReaderT $ + withInstantiated piTy [x] \ty -> substM ty +typeOfTabApp _ _ = error "expected a TabPi type" + +typeOfApp :: (IRRep r, EnvReader m) => Type r n -> [Atom r n] -> m n (Type r n) +typeOfApp (TyCon (Pi piTy)) xs = withSubstReaderT $ + withInstantiated piTy xs \(EffTy _ ty) -> substM ty +typeOfApp _ _ = error "expected a pi type" + +repValAtom :: EnvReader m => RepVal n -> m n (SAtom n) repValAtom (RepVal ty tree) = case ty of - ProdTy ts -> case tree of - Branch trees -> ProdVal <$> mapM repValAtom (zipWith RepVal ts trees) + TyCon (ProdType ts) -> case tree of + Branch trees -> toAtom <$> ProdCon <$> mapM repValAtom (zipWith RepVal ts trees) _ -> malformed - BaseTy _ -> case tree of + TyCon (BaseType _) -> case tree of Leaf x -> case x of - ILit l -> return $ Con $ Lit l + ILit l -> return $ toAtom $ Lit l _ -> fallback _ -> malformed + -- TODO: make sure this covers all the cases. Maybe only TabPi should hit the + -- fallback? This could be a place where we accidentally violate the `Stuck` + -- assumption _ -> fallback - where fallback = return $ RepValAtom $ RepVal ty tree + where fallback = return $ Stuck ty $ RepValAtom $ RepVal ty tree malformed = error "malformed repval" {-# INLINE repValAtom #-} -liftSimpType :: EnvReader m => SType n -> m n (CType n) -liftSimpType = \case - BaseTy t -> return $ BaseTy t - ProdTy ts -> ProdTy <$> mapM rec ts - SumTy ts -> SumTy <$> mapM rec ts - t -> error $ "not implemented: " ++ pprint t - where rec = liftSimpType -{-# INLINE liftSimpType #-} - -liftSimpAtom :: EnvReader m => Type CoreIR n -> SAtom n -> m n (CAtom n) -liftSimpAtom ty simpAtom = case simpAtom of - Var _ -> justLift - ProjectElt _ _ _ -> justLift - RepValAtom _ -> justLift -- TODO(dougalm): should we make more effort to pull out products etc? - _ -> do - (cons , ty') <- unwrapLeadingNewtypesType ty - atom <- case (ty', simpAtom) of - (BaseTy _ , Con (Lit v)) -> return $ Con $ Lit v - (ProdTy tys, Con (ProdCon xs)) -> Con . ProdCon <$> zipWithM rec tys xs - (SumTy tys, Con (SumCon _ i x)) -> Con . SumCon tys i <$> rec (tys!!i) x - (DepPairTy dpt@(DepPairType _ (b:>t1) t2), DepPair x1 x2 _) -> do - x1' <- rec t1 x1 - t2' <- applySubst (b@>SubstVal x1') t2 - x2' <- rec t2' x2 - return $ DepPair x1' x2' dpt - _ -> error $ "can't lift " <> pprint simpAtom <> " to " <> pprint ty' - return $ wrapNewtypesData cons atom - where - rec = liftSimpAtom - justLift = return $ SimpInCore $ LiftSimp ty simpAtom -{-# INLINE liftSimpAtom #-} - -liftSimpFun :: EnvReader m => Type CoreIR n -> LamExpr SimpIR n -> m n (CAtom n) -liftSimpFun (Pi piTy) f = return $ SimpInCore $ LiftSimpFun piTy f -liftSimpFun _ _ = error "not a pi type" - --- See Note [Confuse GHC] from Simplify.hs -confuseGHC :: IRRep r => CheapReducerM r i n (DistinctEvidence n) -confuseGHC = getDistinct -{-# INLINE confuseGHC #-} - --- TODO: These used to live in QueryType. Think about a better way to organize --- them. Maybe a common set of low-level type-querying utils that both --- CheapReduction and QueryType import? - depPairLeftTy :: DepPairType r n -> Type r n depPairLeftTy (DepPairType _ (_:>ty) _) = ty {-# INLINE depPairLeftTy #-} +reduceUnwrapM :: CAtom o -> ReducerM i o (CAtom o) +reduceUnwrapM = \case + Con con -> case con of + NewtypeCon _ x -> return x + _ -> error "not a newtype" + Stuck _ e -> mkStuck $ StuckUnwrap e + +reduceTabAppM :: IRRep r => Atom r o -> Atom r o -> ReducerM i o (Atom r o) +reduceTabAppM tab x = case tab of + Stuck _ tab' -> mkStuck (StuckTabApp tab' x) + _ -> error $ "not a table" ++ pprint tab + unwrapNewtypeType :: EnvReader m => NewtypeTyCon n -> m n (NewtypeCon n, Type CoreIR n) unwrapNewtypeType = \case Nat -> return (NatCon, IdxRepTy) @@ -441,27 +281,6 @@ unwrapNewtypeType = \case ty -> error $ "Shouldn't be projecting: " ++ pprint ty {-# INLINE unwrapNewtypeType #-} -projType :: (IRRep r, EnvReader m) => Int -> Type r n -> Atom r n -> m n (Type r n) -projType i ty x = case ty of - ProdTy xs -> return $ xs !! i - DepPairTy t | i == 0 -> return $ depPairLeftTy t - DepPairTy t | i == 1 -> do - xFst <- normalizeProj (ProjectProduct 0) x - instantiate t [xFst] - _ -> error $ "Can't project type: " ++ pprint ty - -unwrapLeadingNewtypesType :: EnvReader m => CType n -> m n ([NewtypeCon n], CType n) -unwrapLeadingNewtypesType = \case - NewtypeTyCon tyCon -> do - (dataCon, ty) <- unwrapNewtypeType tyCon - (dataCons, ty') <- unwrapLeadingNewtypesType ty - return (dataCon:dataCons, ty') - ty -> return ([], ty) - -wrapNewtypesData :: [NewtypeCon n] -> CAtom n-> CAtom n -wrapNewtypesData [] x = x -wrapNewtypesData (c:cs) x = NewtypeCon c $ wrapNewtypesData cs x - instantiateTyConDef :: EnvReader m => TyConDef n -> TyConParams n -> m n (DataConDefs n) instantiateTyConDef (TyConDef _ _ bs conDefs) (TyConParams _ xs) = do applySubst (bs @@> (SubstVal <$> xs)) conDefs @@ -507,19 +326,10 @@ dataDefRep :: DataConDefs n -> CType n dataDefRep (ADTCons cons) = case cons of [] -> error "unreachable" -- There's no representation for a void type [DataConDef _ _ ty _] -> ty - tys -> SumTy $ tys <&> \(DataConDef _ _ ty _) -> ty + tys -> toType $ SumType $ tys <&> \(DataConDef _ _ ty _) -> ty dataDefRep (StructFields fields) = case map snd fields of [ty] -> ty - tys -> ProdTy tys - -makeStructRepVal :: (Fallible1 m, EnvReader m) => TyConName n -> [CAtom n] -> m n (CAtom n) -makeStructRepVal tyConName args = do - TyConDef _ _ _ (StructFields fields) <- lookupTyCon tyConName - case fields of - [_] -> case args of - [arg] -> return arg - _ -> error "wrong number of args" - _ -> return $ ProdVal args + tys -> toType (ProdType tys) -- === traversable terms === @@ -533,7 +343,7 @@ class NonAtomRenamer m i o => Visitor m r i o | m -> i, m -> o where visitPi :: PiType r i -> m (PiType r o) class VisitGeneric (e:: E) (r::IR) | e -> r where - visitGeneric :: Visitor m r i o => e i -> m (e o) + visitGeneric :: HasCallStack => Visitor m r i o => e i -> m (e o) type Visitor2 (m::MonadKind2) r = forall i o . Visitor (m i o) r i o @@ -542,7 +352,7 @@ instance VisitGeneric (Type r) r where visitGeneric = visitType instance VisitGeneric (LamExpr r) r where visitGeneric = visitLam instance VisitGeneric (PiType r) r where visitGeneric = visitPi -visitBlock :: Visitor m r i o => Block r i -> m (Block r o) +visitBlock :: Visitor m r i o => Expr r i -> m (Expr r o) visitBlock b = visitGeneric (LamExpr Empty b) >>= \case LamExpr Empty b' -> return b' _ -> error "not a block" @@ -558,22 +368,19 @@ traverseOpTerm => e r i -> m (e r o) traverseOpTerm e = traverseOp e visitGeneric visitGeneric visitGeneric -visitAtomDefault - :: (IRRep r, Visitor (m i o) r i o, AtomSubstReader v m, EnvReader2 m) - => Atom r i -> m i o (Atom r o) -visitAtomDefault atom = case atom of - Var _ -> atomSubstM atom - SimpInCore _ -> atomSubstM atom - ProjectElt t i x -> ProjectElt <$> visitType t <*> pure i <*> visitGeneric x - _ -> visitAtomPartial atom - visitTypeDefault :: (IRRep r, Visitor (m i o) r i o, AtomSubstReader v m, EnvReader2 m) => Type r i -> m i o (Type r o) -visitTypeDefault = \case - TyVar v -> atomSubstM $ TyVar v - ProjectEltTy t i x -> ProjectEltTy <$> visitType t <*> pure i <*> visitGeneric x - x -> visitTypePartial x +visitTypeDefault ty = case ty of + StuckTy _ _ -> atomSubstM ty + TyCon con -> TyCon <$> visitGeneric con + +visitAtomDefault + :: (IRRep r, Visitor (m i o) r i o, AtomSubstReader v m, EnvReader2 m) + => Atom r i -> m i o (Atom r o) +visitAtomDefault ty = case ty of + Stuck _ _ -> atomSubstM ty + Con con -> Con <$> visitGeneric con visitPiDefault :: (Visitor2 m r, IRRep r, FromName v, AtomSubstReader v m, EnvExtender2 m) @@ -596,47 +403,11 @@ visitBinders (Nest (b:>ty) bs) cont = do visitBinders bs \bs' -> cont $ Nest b' bs' --- XXX: This doesn't handle the `Var`, `ProjectElt`, `SimpInCore` cases. These --- should be handled explicitly beforehand. TODO: split out these cases under a --- separate constructor, perhaps even a `hole` paremeter to `Atom` or part of --- `IR`. -visitAtomPartial :: (IRRep r, Visitor m r i o) => Atom r i -> m (Atom r o) -visitAtomPartial = \case - Var _ -> error "Not handled generically" - SimpInCore _ -> error "Not handled generically" - ProjectElt _ _ _ -> error "Not handled generically" - Con con -> Con <$> visitGeneric con - PtrVar t v -> PtrVar t <$> renameN v - DepPair x y t -> do - x' <- visitGeneric x - y' <- visitGeneric y - ~(DepPairTy t') <- visitGeneric $ DepPairTy t - return $ DepPair x' y' t' - Lam lam -> Lam <$> visitGeneric lam - Eff eff -> Eff <$> visitGeneric eff - DictCon t d -> DictCon <$> visitType t <*> visitGeneric d - NewtypeCon con x -> NewtypeCon <$> visitGeneric con <*> visitGeneric x - DictHole ctx ty access -> DictHole ctx <$> visitGeneric ty <*> pure access - TypeAsAtom t -> TypeAsAtom <$> visitGeneric t - RepValAtom repVal -> RepValAtom <$> visitGeneric repVal - --- XXX: This doesn't handle the `TyVar` or `ProjectEltTy` cases. These should be --- handled explicitly beforehand. -visitTypePartial :: (IRRep r, Visitor m r i o) => Type r i -> m (Type r o) -visitTypePartial = \case - TyVar _ -> error "Not handled generically" - ProjectEltTy _ _ _ -> error "Not handled generically" - NewtypeTyCon t -> NewtypeTyCon <$> visitGeneric t - Pi t -> Pi <$> visitGeneric t - TabPi t -> TabPi <$> visitGeneric t - TC t -> TC <$> visitGeneric t - DepPairTy t -> DepPairTy <$> visitGeneric t - DictTy t -> DictTy <$> visitGeneric t - instance IRRep r => VisitGeneric (Expr r) r where visitGeneric = \case + Block _ _ -> error "not handled generically" TopApp et v xs -> TopApp <$> visitGeneric et <*> renameN v <*> mapM visitGeneric xs - TabApp t tab xs -> TabApp <$> visitType t <*> visitGeneric tab <*> mapM visitGeneric xs + TabApp t tab x -> TabApp <$> visitType t <*> visitGeneric tab <*> visitGeneric x -- TODO: should we reuse the original effects? Whether it's valid depends on -- the type-preservation requirements for a visitor. We should clarify what -- those are. @@ -651,6 +422,8 @@ instance IRRep r => VisitGeneric (Expr r) r where PrimOp op -> PrimOp <$> visitGeneric op App et fAtom xs -> App <$> visitGeneric et <*> visitGeneric fAtom <*> mapM visitGeneric xs ApplyMethod et m i xs -> ApplyMethod <$> visitGeneric et <*> visitGeneric m <*> pure i <*> mapM visitGeneric xs + Project t i x -> Project <$> visitGeneric t <*> pure i <*> visitGeneric x + Unwrap t x -> Unwrap <$> visitGeneric t <*> visitGeneric x instance IRRep r => VisitGeneric (PrimOp r) r where visitGeneric = \case @@ -702,19 +475,36 @@ instance IRRep r => VisitGeneric (EffectRow r) r where effs' <- eSetFromList <$> mapM visitGeneric (eSetToList effs) tailEffRow <- case tailVar of NoTail -> return $ EffectRow mempty NoTail - EffectRowTail v -> visitGeneric (Var v) <&> \case - Var v' -> EffectRow mempty (EffectRowTail v') - Eff r -> r + EffectRowTail v -> visitGeneric (toAtom v) <&> \case + Stuck _ (Var v') -> EffectRow mempty (EffectRowTail v') + Con (Eff r) -> r _ -> error "Not a valid effect substitution" return $ extendEffRow effs' tailEffRow -instance VisitGeneric DictExpr CoreIR where +instance IRRep r => VisitGeneric (DictCon r) r where + visitGeneric = \case + InstanceDict t v xs -> InstanceDict <$> visitGeneric t <*> renameN v <*> mapM visitGeneric xs + IxFin x -> IxFin <$> visitGeneric x + DataData dataTy -> DataData <$> visitGeneric dataTy + IxRawFin x -> IxRawFin <$> visitGeneric x + IxSpecialized v xs -> IxSpecialized <$> renameN v <*> mapM visitGeneric xs + +instance IRRep r => VisitGeneric (Con r) r where visitGeneric = \case - InstantiatedGiven x xs -> InstantiatedGiven <$> visitGeneric x <*> mapM visitGeneric xs - SuperclassProj x i -> SuperclassProj <$> visitGeneric x <*> pure i - InstanceDict v xs -> InstanceDict <$> renameN v <*> mapM visitGeneric xs - IxFin x -> IxFin <$> visitGeneric x - DataData t -> DataData <$> visitGeneric t + Lit l -> return $ Lit l + ProdCon xs -> ProdCon <$> mapM visitGeneric xs + SumCon ty con arg -> SumCon <$> mapM visitGeneric ty <*> return con <*> visitGeneric arg + HeapVal -> return HeapVal + DepPair x y t -> do + x' <- visitGeneric x + y' <- visitGeneric y + ~(DepPairTy t') <- visitGeneric $ DepPairTy t + return $ DepPair x' y' t' + Lam lam -> Lam <$> visitGeneric lam + Eff eff -> Eff <$> visitGeneric eff + DictConAtom d -> DictConAtom <$> visitGeneric d + TyConAtom t -> TyConAtom <$> visitGeneric t + NewtypeCon con x -> NewtypeCon <$> visitGeneric con <*> visitGeneric x instance VisitGeneric NewtypeCon CoreIR where visitGeneric = \case @@ -732,17 +522,15 @@ instance VisitGeneric NewtypeTyCon CoreIR where instance VisitGeneric TyConParams CoreIR where visitGeneric (TyConParams expls xs) = TyConParams expls <$> mapM visitGeneric xs -instance IRRep r => VisitGeneric (IxDict r) r where - visitGeneric = \case - IxDictAtom x -> IxDictAtom <$> visitGeneric x - IxDictRawFin x -> IxDictRawFin <$> visitGeneric x - IxDictSpecialized t v xs -> IxDictSpecialized <$> visitGeneric t <*> renameN v <*> mapM visitGeneric xs instance IRRep r => VisitGeneric (IxType r) r where visitGeneric (IxType t d) = IxType <$> visitType t <*> visitGeneric d instance VisitGeneric DictType CoreIR where - visitGeneric (DictType n v xs) = DictType n <$> renameN v <*> mapM visitGeneric xs + visitGeneric = \case + DictType n v xs -> DictType n <$> renameN v <*> mapM visitGeneric xs + IxDictType t -> IxDictType <$> visitGeneric t + DataDictType t -> DataDictType <$> visitGeneric t instance VisitGeneric CoreLamExpr CoreIR where visitGeneric (CoreLamExpr t lam) = CoreLamExpr <$> visitGeneric t <*> visitGeneric lam @@ -765,7 +553,7 @@ instance IRRep r => VisitGeneric (DepPairType r) r where PiType (UnaryNest b') (EffTy Pure ty') -> DepPairType expl b' ty' _ -> error "not a dependent pair type" -instance VisitGeneric (RepVal SimpIR) SimpIR where +instance VisitGeneric RepVal SimpIR where visitGeneric (RepVal ty tree) = RepVal <$> visitGeneric ty <*> mapM renameIExpr tree where renameIExpr = \case ILit l -> return $ ILit l @@ -793,8 +581,25 @@ instance VisitGeneric DataConDef CoreIR where repTy' <- visitGeneric repTy return $ DataConDef sn (Abs bs' UnitE) repTy' ps -instance VisitGeneric (Con r) r where visitGeneric = traverseOpTerm -instance VisitGeneric (TC r) r where visitGeneric = traverseOpTerm +instance IRRep r => VisitGeneric (TyCon r) r where + visitGeneric = \case + BaseType bt -> return $ BaseType bt + ProdType tys -> ProdType <$> mapM visitGeneric tys + SumType tys -> SumType <$> mapM visitGeneric tys + RefType h t -> RefType <$> visitGeneric h <*> visitGeneric t + HeapType -> return HeapType + TabPi t -> TabPi <$> visitGeneric t + DepPairTy t -> DepPairTy <$> visitGeneric t + TypeKind -> return TypeKind + DictTy t -> DictTy <$> visitGeneric t + Pi t -> Pi <$> visitGeneric t + NewtypeTyCon t -> NewtypeTyCon <$> visitGeneric t + +instance IRRep r => VisitGeneric (Dict r) r where + visitGeneric = \case + StuckDict ty s -> fromJust <$> toMaybeDict <$> visitGeneric (Stuck ty s) + DictCon con -> DictCon <$> visitGeneric con + instance VisitGeneric (MiscOp r) r where visitGeneric = traverseOpTerm instance VisitGeneric (VectorOp r) r where visitGeneric = traverseOpTerm instance VisitGeneric (MemOp r) r where visitGeneric = traverseOpTerm @@ -802,11 +607,6 @@ instance VisitGeneric (MemOp r) r where visitGeneric = traverseOpTerm -- === SubstE/SubstB instances === -- These live here, as orphan instances, because we normalize as we substitute. -toAtomVar :: (EnvReader m, IRRep r) => AtomName r n -> m n (AtomVar r n) -toAtomVar v = do - ty <- getType <$> lookupAtomName v - return $ AtomVar v ty - bindersToVars :: (EnvReader m, IRRep r) => Nest (Binder r) n' n -> m n [AtomVar r n] bindersToVars bs = do withExtEvidence bs do @@ -814,22 +614,7 @@ bindersToVars bs = do mapM toAtomVar $ nestToNames bs bindersToAtoms :: (EnvReader m, IRRep r) => Nest (Binder r) n' n -> m n [Atom r n] -bindersToAtoms bs = liftM (Var <$>) $ bindersToVars bs - -newtype SubstVisitor i o a = SubstVisitor { runSubstVisitor :: Reader (Env o, Subst AtomSubstVal i o) a } - deriving (Functor, Applicative, Monad, MonadReader (Env o, Subst AtomSubstVal i o)) - -substV :: (Distinct o, SubstE AtomSubstVal e) => e i -> SubstVisitor i o (e o) -substV x = ask <&> \env -> substE env x - -instance Distinct o => NonAtomRenamer (SubstVisitor i o) i o where - renameN = substV - -instance (Distinct o, IRRep r) => Visitor (SubstVisitor i o) r i o where - visitType = substV - visitAtom = substV - visitLam = substV - visitPi = substV +bindersToAtoms bs = liftM (toAtom <$>) $ bindersToVars bs instance Color c => SubstE AtomSubstVal (AtomSubstVal c) where substE (_, env) (Rename name) = env ! name @@ -837,31 +622,89 @@ instance Color c => SubstE AtomSubstVal (AtomSubstVal c) where instance SubstV (SubstVal Atom) (SubstVal Atom) where +instance IRRep r => SubstE AtomSubstVal (IxDict r) where + substE es = \case + StuckDict _ e -> fromJust $ toMaybeDict $ substStuck es e + DictCon con -> DictCon $ substE es con + instance IRRep r => SubstE AtomSubstVal (Atom r) where - substE es@(env, subst) = \case - Var (AtomVar v ty) -> case subst!v of - Rename v' -> Var $ AtomVar v' (substE es ty) - SubstVal x -> x - SimpInCore x -> SimpInCore (substE es x) - ProjectElt _ i x -> do - let x' = substE es x - runEnvReaderM env $ normalizeProj i x' - atom -> runReader (runSubstVisitor $ visitAtomPartial atom) es + substE es = \case + Stuck _ e -> substStuck es e + Con con -> Con $ substE es con instance IRRep r => SubstE AtomSubstVal (Type r) where - substE es@(env, subst) = \case - TyVar (AtomVar v ty) -> case subst ! v of - Rename v' -> TyVar $ AtomVar v' (substE es ty) - SubstVal (Type t) -> t - SubstVal atom -> error $ "bad substitution: " ++ pprint v ++ " -> " ++ pprint atom - ProjectEltTy _ i x -> do - let x' = substE es x - case runEnvReaderM env $ normalizeProj i x' of - Type t -> t - _ -> error "bad substitution" - ty -> runReader (runSubstVisitor $ visitTypePartial ty) es - -instance SubstE AtomSubstVal SimpInCore + substE es = \case + StuckTy _ e -> fromJust $ toMaybeType $ substStuck es e + TyCon con -> TyCon $ substE es con + +substMStuck :: (SubstReader AtomSubstVal m, EnvReader2 m, IRRep r) => Stuck r i -> m i o (Atom r o) +substMStuck stuck = do + subst <- getSubst + env <- unsafeGetEnv + withDistinct $ return $ substStuck (env, subst) stuck + +substStuck :: (IRRep r, Distinct o) => (Env o, Subst AtomSubstVal i o) -> Stuck r i -> Atom r o +substStuck (env, subst) stuck = + ignoreExcept $ runEnvReaderT env $ runSubstReaderT subst $ reduceStuck stuck + +reduceStuck :: (IRRep r, Distinct o) => Stuck r i -> ReducerM i o (Atom r o) +reduceStuck = \case + Var (AtomVar v ty) -> do + lookupSubstM v >>= \case + Rename v' -> toAtom . AtomVar v' <$> substM ty + SubstVal x -> return x + StuckProject i x -> do + x' <- reduceStuck x + dropSubst $ reduceProjM i x' + StuckUnwrap x -> do + x' <- reduceStuck x + dropSubst $ reduceUnwrapM x' + StuckTabApp f x -> do + f' <- reduceStuck f + x' <- substM x + dropSubst $ reduceTabAppM f' x' + InstantiatedGiven f xs -> do + xs' <- mapM substM xs + f' <- reduceStuck f + reduceInstantiateGivenM f' xs' + SuperclassProj superclassIx child -> do + Just child' <- toMaybeDict <$> reduceStuck child + reduceSuperclassProjM superclassIx child' + PtrVar ptrTy ptr -> mkStuck =<< PtrVar ptrTy <$> substM ptr + RepValAtom repVal -> mkStuck =<< RepValAtom <$> substM repVal + LiftSimp t s -> do + t' <- substM t + s' <- reduceStuck s + liftSimpAtom t' s' + LiftSimpFun t f -> mkStuck =<< (LiftSimpFun <$> substM t <*> substM f) + TabLam lam -> mkStuck =<< (TabLam <$> substM lam) + ACase scrut alts resultTy -> do + scrut' <- reduceStuck scrut + resultTy' <- substM resultTy + alts' <- mapM substM alts + reduceACaseM scrut' alts' resultTy' + +liftSimpAtom :: EnvReader m => Type CoreIR n -> SAtom n -> m n (CAtom n) +liftSimpAtom (StuckTy _ _) _ = error "Can't lift stuck type" +liftSimpAtom ty@(TyCon tyCon) simpAtom = case simpAtom of + Stuck _ stuck -> return $ Stuck ty $ LiftSimp ty stuck + Con con -> Con <$> case (tyCon, con) of + (NewtypeTyCon newtypeCon, _) -> do + (dataCon, repTy) <- unwrapNewtypeType newtypeCon + cAtom <- rec repTy (Con con) + return $ NewtypeCon dataCon cAtom + (BaseType _ , Lit v) -> return $ Lit v + (ProdType tys, ProdCon xs) -> ProdCon <$> zipWithM rec tys xs + (SumType tys, SumCon _ i x) -> SumCon tys i <$> rec (tys!!i) x + (DepPairTy dpt@(DepPairType _ (b:>t1) t2), DepPair x1 x2 _) -> do + x1' <- rec t1 x1 + t2' <- applySubst (b@>SubstVal x1') t2 + x2' <- rec t2' x2 + return $ DepPair x1' x2' dpt + _ -> error $ "can't lift " <> pprint simpAtom <> " to " <> pprint ty + where + rec = liftSimpAtom +{-# INLINE liftSimpAtom #-} instance IRRep r => SubstE AtomSubstVal (EffectRow r) where substE env (EffectRow effs tailVar) = do @@ -872,8 +715,8 @@ instance IRRep r => SubstE AtomSubstVal (EffectRow r) where Rename v' -> do let v'' = runEnvReaderM (fst env) $ toAtomVar v' EffectRow mempty (EffectRowTail v'') - SubstVal (Var v') -> EffectRow mempty (EffectRowTail v') - SubstVal (Eff r) -> r + SubstVal (Stuck _ (Var v')) -> EffectRow mempty (EffectRowTail v') + SubstVal (Con (Eff r)) -> r _ -> error "Not a valid effect substitution" extendEffRow effs' tailEffRow @@ -883,21 +726,22 @@ instance SubstE AtomSubstVal SpecializationSpec where substE env (AppSpecialization (AtomVar f _) ab) = do let f' = case snd env ! f of Rename v -> runEnvReaderM (fst env) $ toAtomVar v - SubstVal (Var v) -> v + SubstVal (Stuck _ (Var v)) -> v _ -> error "bad substitution" AppSpecialization f' (substE env ab) instance SubstE AtomSubstVal EffectDef instance SubstE AtomSubstVal EffectOpType instance SubstE AtomSubstVal IExpr -instance IRRep r => SubstE AtomSubstVal (RepVal r) +instance SubstE AtomSubstVal RepVal instance SubstE AtomSubstVal TyConParams instance SubstE AtomSubstVal DataConDef instance IRRep r => SubstE AtomSubstVal (BaseMonoid r) instance IRRep r => SubstE AtomSubstVal (DAMOp r) instance IRRep r => SubstE AtomSubstVal (TypedHof r) instance IRRep r => SubstE AtomSubstVal (Hof r) -instance IRRep r => SubstE AtomSubstVal (TC r) +instance IRRep r => SubstE AtomSubstVal (TyCon r) +instance IRRep r => SubstE AtomSubstVal (DictCon r) instance IRRep r => SubstE AtomSubstVal (Con r) instance IRRep r => SubstE AtomSubstVal (MiscOp r) instance IRRep r => SubstE AtomSubstVal (VectorOp r) @@ -909,7 +753,6 @@ instance IRRep r => SubstE AtomSubstVal (Expr r) instance IRRep r => SubstE AtomSubstVal (GenericOpRep const r) instance SubstE AtomSubstVal InstanceBody instance SubstE AtomSubstVal DictType -instance SubstE AtomSubstVal DictExpr instance IRRep r => SubstE AtomSubstVal (LamExpr r) instance SubstE AtomSubstVal CorePiType instance SubstE AtomSubstVal CoreLamExpr @@ -921,6 +764,5 @@ instance IRRep r => SubstE AtomSubstVal (DeclBinding r) instance IRRep r => SubstB AtomSubstVal (Decl r) instance SubstE AtomSubstVal NewtypeTyCon instance SubstE AtomSubstVal NewtypeCon -instance IRRep r => SubstE AtomSubstVal (IxDict r) instance IRRep r => SubstE AtomSubstVal (IxType r) instance SubstE AtomSubstVal DataConDefs diff --git a/src/lib/CheckType.hs b/src/lib/CheckType.hs index b8677c185..580039a84 100644 --- a/src/lib/CheckType.hs +++ b/src/lib/CheckType.hs @@ -6,7 +6,7 @@ {-# LANGUAGE UndecidableInstances #-} -module CheckType (CheckableE (..), CheckableB (..), checkBlock, checkTypes, checkTypeIs) where +module CheckType (CheckableE (..), CheckableB (..), checkTypes, checkTypeIs) where import Prelude hiding (id) import Control.Category ((>>>)) @@ -22,11 +22,12 @@ import IRVariants import MTL1 import Name import Subst -import PPrint () +import PPrint import QueryType import Types.Core import Types.Primitives import Types.Source +import Types.Top -- === top-level API === @@ -39,16 +40,12 @@ checkTypeIs e ty = liftTyperM (void $ e |: ty) >>= liftExcept -- === the type checking/querying monad === newtype TyperM (r::IR) (i::S) (o::S) (a :: *) = - TyperM { runTyperT' :: SubstReaderT Name (StateT1 (NameMap (AtomNameC r) Int) FallibleEnvReaderM) i o a } + TyperM { runTyperT' :: SubstReaderT Name (StateT1 (NameMap (AtomNameC r) Int) (EnvReaderT Except)) i o a } deriving ( Functor, Applicative, Monad , SubstReader Name , MonadFail , Fallible , ScopeReader - , EnvReader, EnvExtender) + , EnvReader, EnvExtender, Catchable) liftTyperM :: EnvReader m => TyperM r n n a -> m n (Except a) -liftTyperM cont = - liftM runFallibleM $ liftEnvReaderT $ - flip evalStateT1 mempty $ - runSubstReaderT idSubst $ - runTyperT' cont +liftTyperM cont = liftEnvReaderT $ flip evalStateT1 mempty $ runSubstReaderT idSubst $ runTyperT' cont {-# INLINE liftTyperM #-} -- I can't make up my mind whether a `Seq` loop should be allowed to @@ -59,7 +56,7 @@ affineUsed name = TyperM $ do case lookupNameMapE name affines of Just (LiftE n) -> if n > 0 then - throw TypeErr $ "Affine name " ++ pprint name ++ " used " ++ show (n + 1) ++ " times." + throwInternal $ "Affine name " ++ pprint name ++ " used " ++ show (n + 1) ++ " times." else put $ insertNameMapE name (LiftE $ n + 1) affines Nothing -> put $ insertNameMapE name (LiftE 1) affines @@ -91,11 +88,9 @@ checkTypesEq :: IRRep r => Type r o -> Type r o -> TyperM r i o () checkTypesEq reqTy ty = alphaEq reqTy ty >>= \case True -> return () False -> {-# SCC typeNormalization #-} do - reqTy' <- cheapNormalize reqTy - ty' <- cheapNormalize ty - alphaEq reqTy' ty' >>= \case + alphaEq reqTy ty >>= \case True -> return () - False -> throw TypeErr $ pprint reqTy' ++ " != " ++ pprint ty' + False -> throwInternal $ pprint reqTy ++ " != " ++ pprint ty {-# INLINE checkTypesEq #-} class SinkableE e => CheckableE (r::IR) (e::E) | e -> r where @@ -129,6 +124,12 @@ checkAndGetType x = do x' <- checkE x return (x', getType x') +checkWithEffTy :: (CheckableWithEffects r e, HasType r e, IRRep r) => EffTy r o -> e i -> TyperM r i o (e o) +checkWithEffTy (EffTy effs ty) e = do + e' <- checkWithEffects effs e + checkTypesEq ty (getType e') + return e' + instance CheckableE CoreIR SourceMap where checkE sm = renameM sm -- TODO? @@ -144,6 +145,14 @@ instance (CheckableB r b, CheckableE r e) => CheckableE r (Abs b e) where -- === type checking core === +checkStuck :: IRRep r => Type r i -> Stuck r i -> TyperM r i o (Type r o, Stuck r o) +checkStuck ty e = do + e' <- checkE e + ty' <- checkE ty + ty'' <- queryStuckType e' + checkTypesEq ty' ty'' + return (ty', e') + instance IRRep r => CheckableE r (TopLam r) where checkE (TopLam destFlag piTy lam) = do -- TODO: check destination-passing flag @@ -156,52 +165,8 @@ instance IRRep r => CheckableE r (AtomName r) where instance IRRep r => CheckableE r (Atom r) where checkE = \case - Var name -> do - name' <- checkE name - case getType name' of - RawRefTy _ -> affineUsed $ atomVarName name' - _ -> return () - return $ Var name' - Lam lam -> Lam <$> checkE lam - DepPair l r ty -> do - l' <- checkE l - ty' <- checkE ty - rTy <- checkInstantiation ty' [l'] - r' <- r |: rTy - return $ DepPair l' r' ty' - Con con -> Con <$> checkE con - Eff eff -> Eff <$> checkE eff - PtrVar t v -> PtrVar t <$> renameM v - -- TODO: check against cached type - DictCon ty dictExpr -> DictCon <$> checkE ty <*> checkE dictExpr - RepValAtom repVal -> RepValAtom <$> renameM repVal -- TODO: check - NewtypeCon con x -> do - (x', xTy) <- checkAndGetType x - con' <- typeCheckNewtypeCon con xTy - return $ NewtypeCon con' x' - SimpInCore x -> SimpInCore <$> checkE x - DictHole ctx ty access -> do - ty' <- ty |: TyKind - return $ DictHole ctx ty' access - ProjectElt resultTy UnwrapNewtype x -> do - resultTy' <- resultTy |: TyKind - (x', NewtypeTyCon con) <- checkAndGetType x - resultTy'' <- snd <$> unwrapNewtypeType con - checkTypesEq resultTy' resultTy'' - return $ ProjectElt resultTy' UnwrapNewtype x' - ProjectElt resultTy (ProjectProduct i) x -> do - resultTy' <- resultTy |: TyKind - (x', xTy) <- checkAndGetType x - resultTy'' <- case xTy of - ProdTy tys -> return $ tys !! i - DepPairTy t | i == 0 -> return $ depPairLeftTy t - DepPairTy t | i == 1 -> do - xFst <- normalizeProj (ProjectProduct 0) x' - checkInstantiation t [xFst] - _ -> throw TypeErr $ "Not a product type:" ++ pprint xTy - checkTypesEq resultTy' resultTy'' - return $ ProjectElt resultTy' (ProjectProduct i) x' - TypeAsAtom ty -> TypeAsAtom <$> checkE ty + Stuck ty e -> uncurry Stuck <$> checkStuck ty e + Con e -> Con <$> checkE e instance IRRep r => CheckableE r (AtomVar r) where checkE (AtomVar v t1) = do @@ -213,40 +178,8 @@ instance IRRep r => CheckableE r (AtomVar r) where instance IRRep r => CheckableE r (Type r) where checkE = \case - Pi t -> Pi <$> checkE t - TabPi t -> TabPi <$> checkE t - NewtypeTyCon t -> NewtypeTyCon <$> checkE t - TC t -> TC <$> checkE t - DepPairTy t -> DepPairTy <$> checkE t - DictTy (DictType sn className params) -> do - className' <- renameM className - ClassDef _ _ _ _ paramBs _ _ <- lookupClassDef className' - params' <- mapM checkE params - void $ checkInstantiation (Abs paramBs UnitE) params' - return $ DictTy (DictType sn className' params') - TyVar v -> TyVar <$> checkE v - ProjectEltTy resultTy UnwrapNewtype x -> do - resultTy' <- resultTy |: TyKind - x' <- checkE x - NewtypeTyCon con <- return $ getType x' - ty <- snd <$> unwrapNewtypeType con - checkTypesEq resultTy' ty - return $ ProjectEltTy resultTy' UnwrapNewtype x' - ProjectEltTy resultTy (ProjectProduct i) x -> do - resultTy' <- resultTy |: TyKind - (x', ty) <- checkAndGetType x - resultTy'' <- case ty of - ProdTy tys -> return $ tys !! i - DepPairTy t | i == 0 -> return $ depPairLeftTy t - DepPairTy t | i == 1 -> do - xFst <- normalizeProj (ProjectProduct 0) x' - instantiate t [xFst] - _ -> throw TypeErr $ "Not a product type:" ++ pprint ty - checkTypesEq resultTy' resultTy'' - return $ ProjectEltTy resultTy' (ProjectProduct i) x' - -instance CheckableE CoreIR SimpInCore where - checkE x = renameM x -- TODO: check + StuckTy ty e -> uncurry StuckTy <$> checkStuck ty e + TyCon e -> TyCon <$> checkE e instance (ToBinding ann c, Color c, CheckableE r ann) => CheckableB r (BinderP c ann) where checkB (b:>ann) cont = do @@ -265,22 +198,22 @@ checkBinderType ty b cont = do cont b' instance IRRep r => CheckableWithEffects r (Expr r) where - checkWithEffects allowedEffs expr = addContext ("Checking expr:\n" ++ pprint expr) case expr of + checkWithEffects allowedEffs expr = case expr of App effTy f xs -> do effTy' <- checkEffTy allowedEffs effTy f' <- checkE f - Pi piTy <- return $ getType f' + TyCon (Pi piTy) <- return $ getType f' xs' <- mapM checkE xs effTy'' <- checkInstantiation piTy xs' checkAlphaEq effTy' effTy'' return $ App effTy' f' xs' - TabApp reqTy f xs -> do - reqTy' <- reqTy |: TyKind + TabApp reqTy f x -> do + reqTy' <- checkE reqTy (f', tabTy) <- checkAndGetType f - xs' <- mapM checkE xs - ty' <- checkTabApp tabTy xs' + x' <- checkE x + ty' <- checkTabApp tabTy x' checkTypesEq reqTy' ty' - return $ TabApp reqTy' f' xs' + return $ TabApp reqTy' f' x' TopApp effTy f xs -> do f' <- renameM f effTy' <- checkEffTy allowedEffs effTy @@ -291,26 +224,32 @@ instance IRRep r => CheckableWithEffects r (Expr r) where return $ TopApp effTy' f' xs' Atom x -> Atom <$> checkE x PrimOp op -> PrimOp <$> checkWithEffects allowedEffs op + Block effTy (Abs decls body) -> do + effTy'@(EffTy effs ty) <- checkEffTy allowedEffs effTy + checkDecls effs decls \decls' -> do + body' <- checkWithEffects (sink effs) body + checkTypesEq (sink ty) (getType body') + return $ Block effTy' $ Abs decls' body' Case scrut alts effTy -> do effTy' <- checkEffTy allowedEffs effTy scrut' <- checkE scrut - altsBinderTys <- checkCaseAltsBinderTys $ getType scrut' + TyCon (SumType altsBinderTys) <- return $ getType scrut' assertEq (length altsBinderTys) (length alts) "" alts' <- parallelAffines $ (zip alts altsBinderTys) <&> \(Abs b body, reqBinderTy) -> do checkB b \b' -> do checkTypesEq (sink reqBinderTy) (sink $ binderType b') - Abs b' <$> checkBlock (sink effTy') body + Abs b' <$> checkWithEffTy (sink effTy') body return $ Case scrut' alts' effTy' ApplyMethod effTy dict i args -> do effTy' <- checkEffTy allowedEffs effTy - dict' <- checkE dict + Just dict' <- toMaybeDict <$> checkE dict args' <- mapM checkE args methodTy <- getMethodType dict' i effTy'' <- checkInstantiation methodTy args' checkAlphaEq effTy' effTy'' - return $ ApplyMethod effTy' dict' i args' + return $ ApplyMethod effTy' (toAtom dict') i args' TabCon maybeD ty xs -> do - ty'@(TabPi (TabPiType _ b restTy)) <- ty |: TyKind + ty'@(TyCon (TabPi (TabPiType _ b restTy))) <- checkE ty maybeD' <- mapM renameM maybeD -- TODO: check xs' <- case fromConstAbs (Abs b restTy) of HoistSuccess elTy -> forM xs (|: elTy) @@ -319,31 +258,81 @@ instance IRRep r => CheckableWithEffects r (Expr r) where -- each index from the ix dict. HoistFailure _ -> forM xs checkE return $ TabCon maybeD' ty' xs' + Project resultTy i x -> do + x' <-checkE x + resultTy' <- checkE resultTy + resultTy'' <- checkProject i x' + checkTypesEq resultTy' resultTy'' + return $ Project resultTy' i x' + Unwrap resultTy x -> do + resultTy' <- checkE resultTy + (x', TyCon (NewtypeTyCon con)) <- checkAndGetType x + resultTy'' <- snd <$> unwrapNewtypeType con + checkTypesEq resultTy' resultTy'' + return $ Unwrap resultTy' x' instance CheckableE CoreIR TyConParams where checkE (TyConParams expls params) = TyConParams expls <$> mapM checkE params -instance CheckableE CoreIR DictExpr where +instance IRRep r => CheckableE r (Stuck r) where checkE = \case - InstanceDict instanceName args -> do - instanceName' <- renameM instanceName - args' <- mapM checkE args - instanceDef <- lookupInstanceDef instanceName' - void $ checkInstantiation instanceDef args' - return $ InstanceDict instanceName' args' + Var name -> do + name' <- checkE name + case getType name' of + RawRefTy _ -> affineUsed $ atomVarName name' + _ -> return () + return $ Var name' + StuckUnwrap x -> do + x' <- checkE x + TyCon (NewtypeTyCon _) <- queryStuckType x' + return $ StuckUnwrap x' + StuckProject i x -> do + x' <-checkE x + x'' <- mkStuck x' + void $ checkProject i x'' + return $ StuckProject i x' + StuckTabApp f x -> do + f' <- checkE f + tabTy <- queryStuckType f' + x' <- checkE x + void $ checkTabApp tabTy x' + return $ StuckTabApp f' x' InstantiatedGiven given args -> do - (given', Pi piTy) <- checkAndGetType given + given' <- checkE given + TyCon (Pi piTy) <- queryStuckType given' args' <- mapM checkE args EffTy Pure _ <- checkInstantiation piTy args' return $ InstantiatedGiven given' args' - SuperclassProj d i -> SuperclassProj <$> checkE d <*> pure i -- TODO: check index in range + SuperclassProj i d -> SuperclassProj <$> pure i <*> checkE d -- TODO: check index in range + PtrVar t v -> PtrVar t <$> renameM v + RepValAtom repVal -> RepValAtom <$> renameM repVal -- TODO: check + LiftSimp t x -> LiftSimp <$> checkE t <*> renameM x -- TODO: check + LiftSimpFun t x -> LiftSimpFun <$> checkE t <*> renameM x -- TODO: check + ACase scrut alts resultTy -> ACase <$> renameM scrut <*> mapM renameM alts <*> checkE resultTy -- TODO: check + TabLam lam -> TabLam <$> renameM lam -- TODO: check + +depPairLeftTy :: DepPairType r n -> Type r n +depPairLeftTy (DepPairType _ (_:>ty) _) = ty +{-# INLINE depPairLeftTy #-} + +instance IRRep r => CheckableE r (DictCon r) where + checkE = \case + InstanceDict ty instanceName args -> do + ty' <- checkE ty + instanceName' <- renameM instanceName + args' <- mapM checkE args + instanceDef <- lookupInstanceDef instanceName' + void $ checkInstantiation instanceDef args' + return $ InstanceDict ty' instanceName' args' IxFin n -> IxFin <$> n |: NatTy - DataData ty -> DataData <$> ty |: TyKind + DataData dataTy -> DataData <$> checkE dataTy + IxRawFin n -> IxRawFin <$> n |: IdxRepTy + IxSpecialized v params -> IxSpecialized <$> renameM v <*> mapM checkE params instance IRRep r => CheckableE r (DepPairType r) where checkE (DepPairType expl b ty) = do checkB b \b' -> do - ty' <- ty |: TyKind + ty' <- checkE ty return $ DepPairType expl b' ty' instance CheckableE CoreIR CorePiType where @@ -368,7 +357,7 @@ instance IRRep r => CheckableE r (TabPiType r) where checkE (TabPiType d b resultTy) = do d' <- checkE d checkB b \b' -> do - resultTy' <- resultTy|:TyKind + resultTy' <- checkE resultTy return $ TabPiType d' b' resultTy' instance (BindsNames b, CheckableB r b) => CheckableB r (Nest b) where @@ -386,25 +375,57 @@ instance CheckableE CoreIR CoreLamExpr where lamExpr' <- checkLamExpr (PiType bs effTy) lamExpr return $ CoreLamExpr (CorePiType expl expls bs effTy) lamExpr' -instance IRRep r => CheckableE r (TC r) where +instance IRRep r => CheckableE r (TyCon r) where checkE = \case BaseType b -> return $ BaseType b - ProdType tys -> ProdType <$> mapM (|:TyKind) tys - SumType cs -> SumType <$> mapM (|:TyKind) cs - RefType r a -> RefType <$> r|:TC HeapType <*> a|:TyKind + ProdType tys -> ProdType <$> mapM checkE tys + SumType cs -> SumType <$> mapM checkE cs + RefType r a -> RefType <$> r|:TyCon HeapType <*> checkE a TypeKind -> return TypeKind HeapType -> return HeapType + Pi t -> Pi <$> checkE t + TabPi t -> TabPi <$> checkE t + NewtypeTyCon t -> NewtypeTyCon <$> checkE t + DepPairTy t -> DepPairTy <$> checkE t + DictTy t -> DictTy <$> checkE t + + +instance CheckableE CoreIR DictType where + checkE = \case + DictType sn className params -> do + className' <- renameM className + ClassDef _ Nothing _ _ _ paramBs _ _ <- lookupClassDef className' + params' <- mapM checkE params + void $ checkInstantiation (Abs paramBs UnitE) params' + return $ DictType sn className' params' + IxDictType t -> IxDictType <$> checkE t + DataDictType t -> DataDictType <$> checkE t instance IRRep r => CheckableE r (Con r) where checkE = \case Lit l -> return $ Lit l ProdCon xs -> ProdCon <$> mapM checkE xs SumCon tys tag payload -> do - tys' <- mapM (|:TyKind) tys - unless (0 <= tag && tag < length tys') $ throw TypeErr "Invalid SumType tag" + tys' <- mapM checkE tys + unless (0 <= tag && tag < length tys') $ throwInternal "Invalid SumType tag" payload' <- payload |: (tys' !! tag) return $ SumCon tys' tag payload' HeapVal -> return HeapVal + Lam lam -> Lam <$> checkE lam + DepPair l r ty -> do + l' <- checkE l + ty' <- checkE ty + rTy <- checkInstantiation ty' [l'] + r' <- r |: rTy + return $ DepPair l' r' ty' + Eff eff -> Eff <$> checkE eff + -- TODO: check against cached type + DictConAtom con -> DictConAtom <$> checkE con + NewtypeCon con x -> do + (x', xTy) <- checkAndGetType x + con' <- typeCheckNewtypeCon con xTy + return $ NewtypeCon con' x' + TyConAtom tyCon -> TyConAtom <$> checkE tyCon typeCheckNewtypeCon :: NewtypeCon i -> CType o -> TyperM CoreIR i o (NewtypeCon o) @@ -448,20 +469,20 @@ instance IRRep r => CheckableWithEffects r (PrimOp r) where BinOp binop x y -> do x' <- checkE x y' <- checkE y - TC (BaseType xTy) <- return $ getType x' - TC (BaseType yTy) <- return $ getType y' + TyCon (BaseType xTy) <- return $ getType x' + TyCon (BaseType yTy) <- return $ getType y' checkBinOp binop xTy yTy return $ BinOp binop x' y' UnOp unop x -> do x' <- checkE x - TC (BaseType xTy) <- return $ getType x' + TyCon (BaseType xTy) <- return $ getType x' checkUnOp unop xTy return $ UnOp unop x' MiscOp op -> MiscOp <$> checkWithEffects effs op MemOp op -> MemOp <$> checkWithEffects effs op DAMOp op -> DAMOp <$> checkWithEffects effs op RefOp ref m -> do - (ref', TC (RefType h s)) <- checkAndGetType ref + (ref', TyCon (RefType h s)) <- checkAndGetType ref m' <- case m of MGet -> declareEff effs (RWSEffect State h) $> MGet MPut x -> do @@ -475,22 +496,22 @@ instance IRRep r => CheckableWithEffects r (PrimOp r) where declareEff effs (RWSEffect Writer h) return $ MExtend b' x' IndexRef givenTy i -> do - givenTy' <- givenTy |: TyKind - TabPi tabTy <- return s + givenTy' <- checkE givenTy + TyCon (TabPi tabTy) <- return s i' <- checkE i eltTy' <- checkInstantiation tabTy [i'] - checkTypesEq givenTy' (TC $ RefType h eltTy') + checkTypesEq givenTy' (TyCon $ RefType h eltTy') return $ IndexRef givenTy' i' ProjRef givenTy p -> do - givenTy' <- givenTy |: TyKind + givenTy' <- checkE givenTy resultEltTy <- case p of ProjectProduct i -> do - ProdTy tys <- return s + TyCon (ProdType tys) <- return s return $ tys !! i UnwrapNewtype -> do - NewtypeTyCon tc <- return s + TyCon (NewtypeTyCon tc) <- return s snd <$> unwrapNewtypeType tc - checkTypesEq givenTy' (TC $ RefType h resultEltTy) + checkTypesEq givenTy' (TyCon $ RefType h resultEltTy) return $ ProjRef givenTy' p return $ RefOp ref' m' @@ -535,29 +556,29 @@ instance IRRep r => CheckableWithEffects r (MiscOp r) where x' <- checkE x y' <- y |: getType x' return $ Select p' x' y' - CastOp t@(TyVar _) e -> CastOp <$> (t|:TyKind) <*> renameM e + CastOp t@(StuckTy _ (Var _)) e -> CastOp <$> checkE t <*> renameM e CastOp destTy e -> do e' <- checkE e - destTy' <- destTy |: TyKind + destTy' <- checkE destTy checkValidCast (getType e') destTy' return $ CastOp destTy' e' - BitcastOp t@(TyVar _) e -> BitcastOp <$> (t|:TyKind) <*> renameM e + BitcastOp t@(StuckTy _ (Var _)) e -> BitcastOp <$> checkE t <*> renameM e BitcastOp destTy e -> do - destTy' <- destTy |: TyKind + destTy' <- checkE destTy e' <- checkE e let sourceTy = getType e' case (destTy', sourceTy) of (BaseTy dbt@(Scalar _), BaseTy sbt@(Scalar _)) | sizeOf sbt == sizeOf dbt -> return $ BitcastOp destTy' e' - _ -> throw TypeErr $ "Invalid bitcast: " ++ pprint sourceTy ++ " -> " ++ pprint destTy - UnsafeCoerce t e -> UnsafeCoerce <$> t|:TyKind <*> renameM e - GarbageVal t -> GarbageVal <$> (t|:TyKind) + _ -> throwInternal $ "Invalid bitcast: " ++ pprint sourceTy ++ " -> " ++ pprint destTy + UnsafeCoerce t e -> UnsafeCoerce <$> checkE t <*> renameM e + GarbageVal t -> GarbageVal <$> checkE t SumTag x -> do x' <- checkE x void $ checkSomeSumType $ getType x' return $ SumTag x' ToEnum t x -> do - t' <- t |: TyKind + t' <- checkE t x' <- x |: Word8Ty cases <- checkSomeSumType t' forM_ cases \cty -> checkTypesEq cty UnitTy @@ -568,76 +589,70 @@ instance IRRep r => CheckableWithEffects r (MiscOp r) where x' <- checkE x BaseTy (Scalar _) <- return $ getType x' return $ ShowScalar x' - ThrowError ty -> ThrowError <$> (ty|:TyKind) + ThrowError ty -> ThrowError <$> checkE ty ThrowException ty -> ThrowException <$> do declareEff effs ExceptionEffect - ty|:TyKind + checkE ty checkSomeSumType :: IRRep r => Type r o -> TyperM r i o [Type r o] checkSomeSumType = \case - SumTy cases -> return cases - NewtypeTyCon con -> do - (_, SumTy cases) <- unwrapNewtypeType con + TyCon (SumType cases) -> return cases + TyCon (NewtypeTyCon con) -> do + (_, TyCon (SumType cases)) <- unwrapNewtypeType con return cases t -> error $ "not some sum type: " ++ pprint t instance IRRep r => CheckableE r (VectorOp r) where checkE = \case VectorBroadcast v ty -> do - ty'@(BaseTy (Vector _ sbt)) <- ty |: TyKind + ty'@(BaseTy (Vector _ sbt)) <- checkE ty v' <- v |: BaseTy (Scalar sbt) return $ VectorBroadcast v' ty' VectorIota ty -> do - ty'@(BaseTy (Vector _ _)) <- ty |: TyKind + ty'@(BaseTy (Vector _ _)) <- checkE ty return $ VectorIota ty' VectorIdx tbl i ty -> do tbl' <- checkE tbl TabTy _ b (BaseTy (Scalar sbt)) <- return $ getType tbl' i' <- i |: binderType b - ty'@(BaseTy (Vector _ sbt')) <- ty |: TyKind - unless (sbt == sbt') $ throw TypeErr "Scalar type mismatch" + ty'@(BaseTy (Vector _ sbt')) <- checkE ty + unless (sbt == sbt') $ throwInternal "Scalar type mismatch" return $ VectorIdx tbl' i' ty' VectorSubref ref i ty -> do ref' <- checkE ref RefTy _ (TabTy _ b (BaseTy (Scalar sbt))) <- return $ getType ref' i' <- i |: binderType b - ty'@(BaseTy (Vector _ sbt')) <- ty |: TyKind - unless (sbt == sbt') $ throw TypeErr "Scalar type mismatch" + ty'@(BaseTy (Vector _ sbt')) <- checkE ty + unless (sbt == sbt') $ throwInternal "Scalar type mismatch" return $ VectorSubref ref' i' ty' -checkBlock :: IRRep r => EffTy r o -> Block r i -> TyperM r i o (Block r o) -checkBlock (EffTy effs ty) (Abs decls result) = - checkDecls effs decls \decls' -> do - result' <- result |: sink ty - return $ Abs decls' result' - checkHof :: IRRep r => EffTy r o -> Hof r i -> TyperM r i o (Hof r o) checkHof (EffTy effs reqTy) = \case For dir ixTy f -> do IxType t d <- checkE ixTy LamExpr (UnaryNest b) body <- return f - TabPi tabTy <- return reqTy + TyCon (TabPi tabTy) <- return reqTy checkBinderType t b \b' -> do - resultTy <- checkInstantiation (sink tabTy) [Var $ binderVar b'] - body' <- checkBlock (EffTy (sink effs) resultTy) body + resultTy <- checkInstantiation (sink tabTy) [toAtom $ binderVar b'] + body' <- checkWithEffTy (EffTy (sink effs) resultTy) body return $ For dir (IxType t d) (LamExpr (UnaryNest b') body') While body -> do let effTy = EffTy effs (BaseTy $ Scalar Word8Type) checkTypesEq reqTy UnitTy - While <$> checkBlock effTy body + While <$> checkWithEffTy effTy body Linearize f x -> do (x', xTy) <- checkAndGetType x LamExpr (UnaryNest b) body <- return f checkBinderType xTy b \b' -> do PairTy resultTy fLinTy <- sinkM reqTy - body' <- checkBlock (EffTy Pure resultTy) body - checkTypesEq fLinTy (Pi $ nonDepPiType [sink xTy] Pure resultTy) + body' <- checkWithEffTy (EffTy Pure resultTy) body + checkTypesEq fLinTy (toType $ nonDepPiType [sink xTy] Pure resultTy) return $ Linearize (LamExpr (UnaryNest b') body') x' Transpose f x -> do (x', xTy) <- checkAndGetType x LamExpr (UnaryNest b) body <- return f checkB b \b' -> do - body' <- checkBlock (EffTy Pure (sink xTy)) body + body' <- checkWithEffTy (EffTy Pure (sink xTy)) body checkTypesEq (sink $ binderType b') (sink reqTy) return $ Transpose (LamExpr (UnaryNest b') body') x' RunReader r f -> do @@ -672,13 +687,15 @@ checkHof (EffTy effs reqTy) = \case declareEff effs InitEffect Just <$> dest |: RawRefTy sTy return $ RunState d' s' f' - RunIO body -> RunIO <$> checkBlock (EffTy (extendEffect IOEffect effs) reqTy) body - RunInit body -> RunInit <$> checkBlock (EffTy (extendEffect InitEffect effs) reqTy) body + RunIO body -> RunIO <$> checkWithEffTy (EffTy (extendEffect IOEffect effs) reqTy) body + RunInit body -> RunInit <$> checkWithEffTy (EffTy (extendEffect InitEffect effs) reqTy) body CatchException reqTy' body -> do reqTy'' <- checkE reqTy' checkTypesEq reqTy reqTy'' - TypeCon _ _ (TyConParams _[Type ty]) <- return reqTy'' -- TODO: take more care in unpacking Maybe - body' <- checkBlock (EffTy (extendEffect ExceptionEffect effs) ty) body + -- TODO: take more care in unpacking Maybe + TyCon (NewtypeTyCon (UserADTType _ _ (TyConParams _ [ty]))) <- return reqTy'' + Just ty' <- return $ toMaybeType ty + body' <- checkWithEffTy (EffTy (extendEffect ExceptionEffect effs) ty') body return $ CatchException reqTy'' body' instance IRRep r => CheckableWithEffects r (DAMOp r) where @@ -689,13 +706,13 @@ instance IRRep r => CheckableWithEffects r (DAMOp r) where checkExtends effs effAnn' ixTy' <- checkE ixTy (carry', carryTy') <- checkAndGetType carry - let badCarry = throw TypeErr $ "Seq carry should be a product of raw references, got: " ++ pprint carryTy' + let badCarry = throwInternal $ "Seq carry should be a product of raw references, got: " ++ pprint carryTy' case carryTy' of - ProdTy refTys -> forM_ refTys \case RawRefTy _ -> return (); _ -> badCarry + TyCon (ProdType refTys) -> forM_ refTys \case RawRefTy _ -> return (); _ -> badCarry _ -> badCarry let binderReqTy = PairTy (ixTypeType ixTy') carryTy' checkBinderType binderReqTy b \b' -> do - body' <- checkBlock (EffTy (sink effAnn') UnitTy) body + body' <- checkWithEffTy (EffTy (sink effAnn') UnitTy) body return $ Seq effAnn' dir ixTy' carry' $ LamExpr (UnaryNest b') body' RememberDest effAnn d lam -> do LamExpr (UnaryNest b) body <- return lam @@ -703,9 +720,9 @@ instance IRRep r => CheckableWithEffects r (DAMOp r) where checkExtends effs effAnn' (d', dTy@(RawRefTy _)) <- checkAndGetType d checkBinderType dTy b \b' -> do - body' <- checkBlock (EffTy (sink effAnn') UnitTy) body + body' <- checkWithEffTy (EffTy (sink effAnn') UnitTy) body return $ RememberDest effAnn' d' $ LamExpr (UnaryNest b') body' - AllocDest ty -> AllocDest <$> ty|:TyKind + AllocDest ty -> AllocDest <$> checkE ty Place ref val -> do val' <- checkE val ref' <- ref |: RawRefTy (getType val') @@ -719,8 +736,8 @@ instance IRRep r => CheckableWithEffects r (DAMOp r) where checkLamExpr :: IRRep r => PiType r o -> LamExpr r i -> TyperM r i o (LamExpr r o) checkLamExpr piTy (LamExpr bs body) = checkB bs \bs' -> do - effTy <- checkInstantiation (sink piTy) (Var <$> bindersVars bs') - body' <- checkBlock effTy body + effTy <- checkInstantiation (sink piTy) (toAtom <$> bindersVars bs') + body' <- checkWithEffTy effTy body return $ LamExpr bs' body' checkDecls @@ -741,32 +758,27 @@ checkRWSAction -> RWS -> LamExpr r i -> TyperM r i o (LamExpr r o) checkRWSAction resultTy referentTy effs rws f = do BinaryLamExpr bH bR body <- return f - checkBinderType (TC HeapType) bH \bH' -> do - let h = Var $ binderVar bH' + checkBinderType (TyCon HeapType) bH \bH' -> do + let h = toAtom $ binderVar bH' let refTy = RefTy h (sink referentTy) checkBinderType refTy bR \bR' -> do let effs' = extendEffect (RWSEffect rws $ sink h) (sink effs) - body' <- checkBlock (EffTy effs' (sink resultTy)) body + body' <- checkWithEffTy (EffTy effs' (sink resultTy)) body return $ BinaryLamExpr bH' bR' body' -checkCaseAltsBinderTys :: IRRep r => Type r n -> TyperM r i n [Type r n] -checkCaseAltsBinderTys ty = case ty of - SumTy types -> return types - NewtypeTyCon t -> case t of - UserADTType _ defName (TyConParams _ params) -> do - def <- lookupTyCon defName - ADTCons cons <- checkInstantiation def params - return [repTy | DataConDef _ _ repTy _ <- cons] - _ -> fail msg - _ -> fail msg - where msg = "Case analysis only supported on ADTs, not on " ++ pprint ty - -checkTabApp :: (IRRep r) => Type r o -> [Atom r o] -> TyperM r i o (Type r o) -checkTabApp ty [] = return ty -checkTabApp ty (i:rest) = do - TabPi tabTy <- return ty - resultTy <- checkInstantiation tabTy [i] - checkTabApp resultTy rest +checkProject :: (IRRep r) => Int -> Atom r o -> TyperM r i o (Type r o) +checkProject i x = case getType x of + TyCon (ProdType tys) -> return $ tys !! i + TyCon (DepPairTy t) | i == 0 -> return $ depPairLeftTy t + TyCon (DepPairTy t) | i == 1 -> do + xFst <- reduceProj 0 x + checkInstantiation t [xFst] + xTy -> throwInternal $ "Not a product type:" ++ pprint xTy + +checkTabApp :: (IRRep r) => Type r o -> Atom r o -> TyperM r i o (Type r o) +checkTabApp ty i = do + TyCon (TabPi tabTy) <- return ty + checkInstantiation tabTy [i] checkInstantiation :: forall r e body i o . @@ -782,7 +794,7 @@ checkInstantiation abTop xsTop = do checkTypesEq (getType x) (binderType b) rest <- applySubst (b@>SubstVal x) (Abs bs body) go rest xs - go _ _ = throw ZipErr "Wrong number of args" + go _ _ = throwInternal "Wrong number of args" checkIntBaseType :: Fallible m => BaseType -> m () checkIntBaseType t = case t of @@ -797,7 +809,7 @@ checkIntBaseType t = case t of Word32Type -> return () Word64Type -> return () _ -> notInt - notInt = throw TypeErr $ + notInt = throwInternal $ "Expected a fixed-width scalar integer type, but found: " ++ pprint t checkFloatBaseType :: Fallible m => BaseType -> m () @@ -810,13 +822,13 @@ checkFloatBaseType t = case t of Float64Type -> return () Float32Type -> return () _ -> notFloat - notFloat = throw TypeErr $ + notFloat = throwInternal $ "Expected a fixed-width scalar floating-point type, but found: " ++ pprint t checkValidCast :: (Fallible1 m, IRRep r) => Type r n -> Type r n -> m n () -checkValidCast (BaseTy l) (BaseTy r) = checkValidBaseCast l r +checkValidCast (TyCon (BaseType l)) (TyCon (BaseType r)) = checkValidBaseCast l r checkValidCast sourceTy destTy = - throw TypeErr $ "Can't cast " ++ pprint sourceTy ++ " to " ++ pprint destTy + throwInternal $ "Can't cast " ++ pprint sourceTy ++ " to " ++ pprint destTy checkValidBaseCast :: Fallible m => BaseType -> BaseType -> m () checkValidBaseCast (PtrType _) (PtrType _) = return () @@ -826,13 +838,13 @@ checkValidBaseCast (Scalar _) (Scalar _) = return () checkValidBaseCast sourceTy@(Vector sourceSizes _) destTy@(Vector destSizes _) = assertEq sourceSizes destSizes $ "Can't cast " ++ pprint sourceTy ++ " to " ++ pprint destTy checkValidBaseCast sourceTy destTy = - throw TypeErr $ "Can't cast " ++ pprint sourceTy ++ " to " ++ pprint destTy + throwInternal $ "Can't cast " ++ pprint sourceTy ++ " to " ++ pprint destTy scalarOrVectorLike :: Fallible m => BaseType -> ScalarBaseType -> m BaseType scalarOrVectorLike x sbt = case x of Scalar _ -> return $ Scalar sbt Vector sizes _ -> return $ Vector sizes sbt - _ -> throw CompilerErr "only scalar or vector base types should occur here" + _ -> throwInternal $ "only scalar or vector base types should occur here" data ArgumentType = SomeFloatArg | SomeIntArg | SomeUIntArg @@ -895,7 +907,7 @@ instance IRRep r => CheckableE r (EffectRow r) where checkE (EffectRow effs effTail) = do effs' <- eSetFromList <$> forM (eSetToList effs) \eff -> case eff of RWSEffect rws v -> do - v' <- v |: TC HeapType + v' <- v |: TyCon HeapType return $ RWSEffect rws v' ExceptionEffect -> return ExceptionEffect IOEffect -> return IOEffect diff --git a/src/lib/ConcreteSyntax.hs b/src/lib/ConcreteSyntax.hs index 06cbc702b..6104c4df1 100644 --- a/src/lib/ConcreteSyntax.hs +++ b/src/lib/ConcreteSyntax.hs @@ -9,7 +9,7 @@ module ConcreteSyntax ( keyWordStrs, showPrimName, parseUModule, parseUModuleDeps, finishUModuleParse, preludeImportBlock, mustParseSourceBlock, - pattern Binary, pattern Prefix, pattern Postfix, pattern Identifier) where + pattern Identifier) where import Control.Monad.Combinators.Expr qualified as Expr import Control.Monad.Reader @@ -17,23 +17,19 @@ import Data.Char import Data.Either import Data.Functor import Data.List.NonEmpty (NonEmpty (..)) -import Data.Map qualified as M import Data.String (fromString) import Data.Text (Text) import Data.Text qualified as T import Data.Text.Encoding qualified as T -import Data.Tuple import Data.Void import Text.Megaparsec hiding (Label, State) import Text.Megaparsec.Char hiding (space, eol) +import Err import Lexing -import Name -import SourceInfo import Types.Core import Types.Source import Types.Primitives -import qualified Types.OpNames as P import Util -- TODO: implement this more efficiently rather than just parsing the whole @@ -61,7 +57,7 @@ parseUModule name s = do {-# SCC parseUModule #-} preludeImportBlock :: SourceBlock -preludeImportBlock = SourceBlock 0 0 LogNothing "" $ Misc $ ImportModule Prelude +preludeImportBlock = SourceBlock 0 0 "" mempty (Misc $ ImportModule Prelude) sourceBlocks :: Parser [SourceBlock] sourceBlocks = manyTill (sourceBlock <* outputLines) eof @@ -72,8 +68,8 @@ mustParseSourceBlock s = mustParseit s sourceBlock -- === helpers for target ADT === -interp_operator :: String -> Bin' -interp_operator = \case +interpOperator :: String -> Bin +interpOperator = \case "&>" -> DepAmpersand "." -> Dot ",>" -> DepComma @@ -84,23 +80,10 @@ interp_operator = \case "->>" -> ImplicitArrow "=>" -> FatArrow "=" -> CSEqual - name -> EvalBinOp $ "(" <> name <> ")" + name -> EvalBinOp $ fromString $ "(" <> name <> ")" -pattern Binary :: Bin' -> Group -> Group -> Group -pattern Binary op lhs rhs <- (WithSrc _ (CBin (WithSrc _ op) lhs rhs)) where - Binary op lhs rhs = joinSrc lhs rhs $ CBin (WithSrc emptySrcPosCtx op) lhs rhs - -pattern Prefix :: SourceName -> Group -> Group -pattern Prefix op g <- (WithSrc _ (CPrefix op g)) where - Prefix op g = WithSrc emptySrcPosCtx $ CPrefix op g - -pattern Postfix :: SourceName -> Group -> Group -pattern Postfix op g <- (WithSrc _ (CPostfix op g)) where - Postfix op g = WithSrc emptySrcPosCtx $ CPostfix op g - -pattern Identifier :: SourceName -> Group -pattern Identifier name <- (WithSrc _ (CIdentifier name)) where - Identifier name = WithSrc emptySrcPosCtx $ CIdentifier name +pattern Identifier :: SourceName -> GroupW +pattern Identifier name <- (WithSrcs _ _ (CLeaf (CIdentifier name))) -- === Parser (top-level structure) === @@ -108,25 +91,23 @@ sourceBlock :: Parser SourceBlock sourceBlock = do offset <- getOffset pos <- getSourcePos - (src, (level, b)) <- withSource $ withRecovery recover $ do - level <- logLevel <|> logTime <|> logBench <|> return LogNothing - b <- sourceBlock' - return (level, b) - return $ SourceBlock (unPos (sourceLine pos)) offset level src b + (src, (lexInfo, b)) <- withSource $ withLexemeInfo $ withRecovery recover $ sourceBlock' + let lexInfo' = lexInfo { lexemeInfo = lexemeInfo lexInfo <&> \(t, (l, r)) -> (t, (l-offset, r-offset))} + return $ SourceBlock (unPos (sourceLine pos)) offset src lexInfo' b -recover :: ParseError Text Void -> Parser (LogLevel, SourceBlock') +recover :: ParseError Text Void -> Parser SourceBlock' recover e = do pos <- liftM statePosState getParserState reachedEOF <- try (mayBreak sc >> eof >> return True) <|> return False consumeTillBreak let errmsg = errorBundlePretty (ParseErrorBundle (e :| []) pos) - return (LogNothing, UnParseable reachedEOF errmsg) + return $ UnParseable reachedEOF errmsg importModule :: Parser SourceBlock' importModule = Misc . ImportModule . OrdinaryModule <$> do keyWord ImportKW - s <- anyCaseName + WithSrc _ s <- anyCaseName eol return s @@ -138,7 +119,7 @@ declareForeign = do void $ label "type annotation" $ sym ":" ty <- cGroup eol - return $ DeclareForeign foreignName b ty + return $ DeclareForeign (fmap fromString foreignName) b ty declareCustomLinearization :: Parser SourceBlock' declareCustomLinearization = do @@ -152,34 +133,6 @@ declareCustomLinearization = do consumeTillBreak :: Parser () consumeTillBreak = void $ manyTill anySingle $ eof <|> void (try (eol >> eol)) -logLevel :: Parser LogLevel -logLevel = do - void $ try $ lexeme $ char '%' >> string "passes" - passes <- many passName - eol - case passes of - [] -> return LogAll - _ -> return $ LogPasses passes - -logTime :: Parser LogLevel -logTime = do - void $ try $ lexeme $ char '%' >> string "time" - eol - return PrintEvalTime - -logBench :: Parser LogLevel -logBench = do - void $ try $ lexeme $ char '%' >> string "bench" - benchName <- strLit - eol - return $ PrintBench benchName - -passName :: Parser PassName -passName = choice [thisNameString s $> x | (s, x) <- passNames] - -passNames :: [(Text, PassName)] -passNames = [(T.pack $ show x, x) | x <- [minBound..maxBound]] - sourceBlock' :: Parser SourceBlock' sourceBlock' = proseBlock @@ -189,39 +142,40 @@ sourceBlock' = <|> hidden (some eol >> return (Misc EmptyLines)) <|> hidden (sc >> eol >> return (Misc CommentLine)) -topDecl :: Parser CTopDecl -topDecl = withSrc $ topDecl' <* eolf +topDecl :: Parser CTopDeclW +topDecl = withSrcs topDecl' <* eolf -topDecl' :: Parser CTopDecl' +topDecl' :: Parser CTopDecl topDecl' = dataDef <|> structDef <|> interfaceDef <|> (CInstanceDecl <$> instanceDef True) <|> (CInstanceDecl <$> instanceDef False) - <|> effectDef proseBlock :: Parser SourceBlock' -proseBlock = label "prose block" $ char '\'' >> fmap (Misc . ProseBlock . fst) (withSource consumeTillBreak) +proseBlock = label "prose block" $ + char '\'' >> fmap (Misc . ProseBlock . fst) (withSource consumeTillBreak) topLevelCommand :: Parser SourceBlock' topLevelCommand = importModule <|> declareForeign <|> declareCustomLinearization - <|> (Misc . QueryEnv <$> envQuery) + -- <|> (Misc . QueryEnv <$> envQuery) <|> explicitCommand "top-level command" -envQuery :: Parser EnvQuery -envQuery = string ":debug" >> sc >> ( - (DumpSubst <$ (string "env" >> sc)) - <|> (InternalNameInfo <$> (string "iname" >> sc >> rawName)) - <|> (SourceNameInfo <$> (string "sname" >> sc >> anyName))) - <* eol - where - rawName :: Parser RawName - rawName = undefined -- RawName <$> (fromString <$> anyName) <*> intLit +_envQuery :: Parser EnvQuery +_envQuery = error "not implemented" +-- string ":debug" >> sc >> ( +-- (DumpSubst <$ (string "env" >> sc)) +-- <|> (InternalNameInfo <$> (string "iname" >> sc >> rawName)) +-- <|> (SourceNameInfo <$> (string "sname" >> sc >> anyName))) +-- <* eol +-- where +-- rawName :: Parser RawName +-- rawName = RawName <$> (fromString <$> anyName) <*> intLit explicitCommand :: Parser SourceBlock' explicitCommand = do @@ -237,13 +191,13 @@ explicitCommand = do b <- cBlock <* eolf e <- case b of ExprBlock e -> return e - IndentedBlock decls -> return $ WithSrc emptySrcPosCtx $ CDo $ IndentedBlock decls + IndentedBlock sid decls -> withSrcs $ return $ CDo $ IndentedBlock sid decls return $ case (e, cmd) of - (WithSrc _ (CIdentifier v), GetType) -> Misc $ GetNameType v + (WithSrcs sid _ (CLeaf (CIdentifier v)), GetType) -> Misc $ GetNameType (WithSrc sid v) _ -> Command cmd e -type CDefBody = ([(SourceName, Group)], [(LetAnn, CDef)]) -structDef :: Parser CTopDecl' +type CDefBody = ([(SourceNameW, GroupW)], [(LetAnn, CDef)]) +structDef :: Parser CTopDecl structDef = do keyWord StructKW tyName <- anyName @@ -263,22 +217,22 @@ structDef = do funDefLetWithAnn :: Parser (LetAnn, CDef) funDefLetWithAnn = do - ann <- noInline <|> return PlainLet + ann <- topLetAnn <|> return PlainLet def <- funDefLet return (ann, def) -dataDef :: Parser CTopDecl' +dataDef :: Parser CTopDecl dataDef = do keyWord DataKW tyName <- anyName (params, givens) <- typeParams dataCons <- onePerLine do dataConName <- anyName - dataConArgs <- optExplicitParams + dataConArgs <- optional explicitParams return (dataConName, dataConArgs) return $ CData tyName params givens dataCons -interfaceDef :: Parser CTopDecl' +interfaceDef :: Parser CTopDecl interfaceDef = do keyWord InterfaceKW className <- anyName @@ -291,27 +245,7 @@ interfaceDef = do return (methodName, ty) return $ CInterface className params methodDecls -effectDef :: Parser CTopDecl' -effectDef = do - keyWord EffectKW - effName <- anyName - sigs <- opSigList - return $ CEffectDecl (fromString effName) sigs - -opSigList :: Parser [(SourceName, UResumePolicy, Group)] -opSigList = onePerLine do - policy <- resumePolicy - v <- anyName - void $ sym ":" - ty <- cGroup - return (fromString v, policy, ty) - -resumePolicy :: Parser UResumePolicy -resumePolicy = (keyWord JmpKW $> UNoResume) - <|> (keyWord DefKW $> ULinearResume) - <|> (keyWord CtlKW $> UAnyResume) - -nameAndType :: Parser (SourceName, Group) +nameAndType :: Parser (SourceNameW, GroupW) nameAndType = do n <- anyName sym ":" @@ -319,20 +253,25 @@ nameAndType = do return (n, arg) topLetOrExpr :: Parser SourceBlock' -topLetOrExpr = withSrc topLet >>= \case - WithSrc _ (CSDecl ann (CExpr e)) -> do +topLetOrExpr = topLet >>= \case + WithSrcs _ _ (CSDecl ann (CExpr e)) -> do when (ann /= PlainLet) $ fail "Cannot annotate expressions" return $ Command (EvalExpr (Printed Nothing)) e d -> return $ TopDecl d -topLet :: Parser CTopDecl' -topLet = do - lAnn <- noInline <|> return PlainLet +topLet :: Parser CTopDeclW +topLet = withSrcs do + lAnn <- topLetAnn <|> return PlainLet decl <- cDecl return $ CSDecl lAnn decl -noInline :: Parser LetAnn -noInline = (char '@' >> string "noinline" $> NoInlineLet) <* nextLine +topLetAnn :: Parser LetAnn +topLetAnn = do + void $ char '@' + ann <- (string "inline" $> InlineLet) + <|> (string "noinline" $> NoInlineLet) + nextLine + return ann onePerLine :: Parser a -> Parser [a] onePerLine p = liftM (:[]) p @@ -345,15 +284,16 @@ cBlock :: Parser CSBlock cBlock = indentedBlock <|> ExprBlock <$> cGroup indentedBlock :: Parser CSBlock -indentedBlock = withIndent $ - IndentedBlock <$> (withSrc cDecl `sepBy1` (semicolon <|> try nextLine)) +indentedBlock = withIndent do + WithSrcs sid _ decls <- withSrcs $ withSrcs cDecl `sepBy1` (void semicolon <|> try nextLine) + return $ IndentedBlock sid decls -cDecl :: Parser CSDecl' +cDecl :: Parser CSDecl cDecl = (CDefDecl <$> funDefLet) <|> simpleLet <|> (keyWord PassKW >> return CPass) -simpleLet :: Parser CSDecl' +simpleLet :: Parser CSDecl simpleLet = do lhs <- cGroupNoEqual next <- nextChar @@ -367,14 +307,14 @@ instanceDef isNamed = do optNameAndArgs <- case isNamed of False -> keyWord InstanceKW $> Nothing True -> keyWord NamedInstanceKW >> do - name <- fromString <$> anyName + name <- anyName args <- (sym ":" >> return Nothing) - <|> ((Just <$> parens (commaSep cParenGroup)) <* sym "->") + <|> ((Just <$> parenList cParenGroup) <* sym "->") return $ Just (name, args) className <- anyName args <- argList givens <- optional givenClause - methods <- withIndent $ withSrc cDecl `sepBy1` try nextLine + methods <- withIndent $ (withSrcs cDecl) `sepBy1` try nextLine return $ CInstanceDef className args givens methods optNameAndArgs funDefLet :: Parser CDef @@ -399,36 +339,37 @@ explicitness = (sym "->" $> ExplicitApp) <|> (sym "->>" $> ImplicitApp) -- Intended for occurrences, like `foo(x, y, z)` (cf. defParamsList). -argList :: Parser [Group] -argList = immediateParens (commaSep cParenGroup) +argList :: Parser [GroupW] +argList = do + WithSrcs _ _ gs <- withSrcs $ bracketedGroup immediateLParen rParen cParenGroup + return gs immediateLParen :: Parser () immediateLParen = label "'(' (without preceding whitespace)" do nextChar >>= \case '(' -> precededByWhitespace >>= \case True -> empty - False -> charLexeme '(' + False -> lParen _ -> empty -immediateParens :: Parser a -> Parser a -immediateParens p = bracketed immediateLParen rParen p - -- Putting `sym =` inside the cases gives better errors. -typeParams :: Parser (ExplicitParams, Maybe GivenClause) +typeParams :: Parser (Maybe ExplicitParams, Maybe GivenClause) typeParams = (explicitParamsAndGivens <* sym "=") - <|> (return ([], Nothing) <* sym "=") + <|> (return (Nothing, Nothing) <* sym "=") -explicitParamsAndGivens :: Parser (ExplicitParams, Maybe GivenClause) -explicitParamsAndGivens = (,) <$> explicitParams <*> optional givenClause - -optExplicitParams :: Parser ExplicitParams -optExplicitParams = label "optional parameter list" $ - explicitParams <|> return [] +explicitParamsAndGivens :: Parser (Maybe ExplicitParams, Maybe GivenClause) +explicitParamsAndGivens = (,) <$> (Just <$> explicitParams) <*> optional givenClause explicitParams :: Parser ExplicitParams explicitParams = label "parameter list in parentheses (without preceding whitespace)" $ - immediateParens $ commaSep cParenGroup + withSrcs $ bracketedGroup immediateLParen rParen cParenGroup + +parenList :: Parser GroupW -> Parser BracketedGroup +parenList p = withSrcs $ bracketedGroup lParen rParen p + +bracketedGroup :: Parser () -> Parser () -> Parser GroupW -> Parser [GroupW] +bracketedGroup l r p = bracketed l r $ commaSep p noGap :: Parser () noGap = precededByWhitespace >>= \case @@ -436,66 +377,63 @@ noGap = precededByWhitespace >>= \case False -> return () givenClause :: Parser GivenClause -givenClause = keyWord GivenKW >> do - (,) <$> parens (commaSep cGroup) - <*> optional (parens (commaSep cGroup)) +givenClause = do + keyWord GivenKW + (,) <$> parenList cGroup <*> optional (parenList cGroup) withClause :: Parser WithClause -withClause = keyWord WithKW >> parens (commaSep cGroup) - -arrowOptEffs :: Parser (Maybe CEffs) -arrowOptEffs = sym "->" >> optional cEffs +withClause = keyWord WithKW >> parenList cGroup cEffs :: Parser CEffs -cEffs = braces do +cEffs = withSrcs $ braces do effs <- commaSep cGroupNoPipe effTail <- optional $ sym "|" >> cGroup return (effs, effTail) commaSep :: Parser a -> Parser [a] -commaSep p = p `sepBy` sym "," +commaSep p = sepBy p (sym ",") -cParenGroup :: Parser Group -cParenGroup = withSrc (CGivens <$> givenClause) <|> cGroup +cParenGroup :: Parser GroupW +cParenGroup = (withSrcs (CGivens <$> givenClause)) <|> cGroup -cGroup :: Parser Group +cGroup :: Parser GroupW cGroup = makeExprParser leafGroup ops -cGroupNoJuxt :: Parser Group +cGroupNoJuxt :: Parser GroupW cGroupNoJuxt = makeExprParser leafGroup $ withoutOp "space" $ withoutOp "." ops -cGroupNoEqual :: Parser Group +cGroupNoEqual :: Parser GroupW cGroupNoEqual = makeExprParser leafGroup $ withoutOp "=" ops -cGroupNoPipe :: Parser Group +cGroupNoPipe :: Parser GroupW cGroupNoPipe = makeExprParser leafGroup $ withoutOp "|" ops -cGroupNoArrow :: Parser Group +cGroupNoArrow :: Parser GroupW cGroupNoArrow = makeExprParser leafGroup $ withoutOp "->" ops -cNullaryLam :: Parser Group' +cNullaryLam :: Parser Group cNullaryLam = do - sym "\\." + void $ sym "\\." body <- cBlock return $ CLambda [] body -cLam :: Parser Group' +cLam :: Parser Group cLam = do - sym "\\" + void $ sym "\\" bs <- many cGroupNoJuxt - mayNotBreak $ sym "." + void $ mayNotBreak $ sym "." body <- cBlock return $ CLambda bs body -cFor :: Parser Group' +cFor :: Parser Group cFor = do kw <- forKW indices <- many cGroupNoJuxt - mayNotBreak $ sym "." + void $ mayNotBreak $ sym "." body <- cBlock return $ CFor kw indices body where forKW = keyWord ForKW $> KFor @@ -503,58 +441,28 @@ cFor = do <|> keyWord RofKW $> KRof <|> keyWord Rof_KW $> KRof_ -cDo :: Parser Group' -cDo = keyWord DoKW >> CDo <$> cBlock +cDo :: Parser Group +cDo = do + keyWord DoKW + CDo <$> cBlock -cCase :: Parser Group' +cCase :: Parser Group cCase = do keyWord CaseKW scrut <- cGroup keyWord OfKW - alts <- onePerLine $ (,) <$> cGroupNoArrow <*> (sym "->" *> cBlock) + alts <- onePerLine cAlt return $ CCase scrut alts --- We support the following syntaxes for `if`: --- - 1-armed then-newline --- if predicate --- then consequent --- if predicate --- then --- consequent1 --- consequent2 --- - 2-armed then-newline else-newline --- if predicate --- then consequent --- else alternate --- and the three other versions where the consequent or the --- alternate are themselves blocks --- - 1-armed then-inline --- if predicate then consequent --- if predicate then --- consequent1 --- consequent2 --- - 2-armed then-inline else-inline --- if predicate then consequent else alternate --- if predicate then consequent else --- alternate1 --- alternate2 --- - Notably, an imagined 2-armed then-inline else-newline --- if predicate then --- consequent1 --- consequent2 --- else alternate --- is not an option, because the indentation lines up badly. To wit, --- one would want the `else` to be indented relative to the `if`, --- but outdented relative to the consequent block, and if the `then` is --- inline there is no indentation level that does that. --- - Last candiate is --- if predicate --- then consequent else alternate --- if predicate --- then consequent else --- alternate1 --- alternate2 -cIf :: Parser Group' +cAlt :: Parser CaseAlt +cAlt = do + pat <- cGroupNoArrow + sym "->" + body <- cBlock + return (pat, body) + +-- see [Note if-syntax] +cIf :: Parser Group cIf = mayNotBreak do keyWord IfKW predicate <- cGroup @@ -563,14 +471,14 @@ cIf = mayNotBreak do thenSameLine :: Parser (CSBlock, Maybe CSBlock) thenSameLine = do - keyWord ThenKW + void $ keyWord ThenKW cBlock >>= \case - IndentedBlock blk -> do + IndentedBlock sid blk -> do let msg = ("No `else` may follow same-line `then` and indented consequent" ++ "; indent and align both `then` and `else`, or write the " ++ "whole `if` on one line.") mayBreak $ noElse msg - return (IndentedBlock blk, Nothing) + return (IndentedBlock sid blk, Nothing) ExprBlock ex -> do alt <- optional $ (keyWord ElseKW >> cBlock) @@ -580,17 +488,17 @@ thenSameLine = do thenNewLine :: Parser (CSBlock, Maybe CSBlock) thenNewLine = withIndent $ do - keyWord ThenKW + void $ keyWord ThenKW cBlock >>= \case - IndentedBlock blk -> do + IndentedBlock sid blk -> do alt <- do -- With `mayNotBreak`, this just forbids inline else noElse ("Same-line `else` may not follow indented consequent;" ++ " put the `else` on the next line.") optional $ do - try $ nextLine >> keyWord ElseKW + void $ try $ nextLine >> keyWord ElseKW cBlock - return (IndentedBlock blk, alt) + return (IndentedBlock sid blk, alt) ExprBlock ex -> do void $ optional $ try nextLine alt <- optional $ keyWord ElseKW >> cBlock @@ -598,59 +506,69 @@ thenNewLine = withIndent $ do noElse :: String -> Parser () noElse msg = (optional $ try $ sc >> keyWord ElseKW) >>= \case - Just () -> fail msg + Just _ -> fail msg Nothing -> return () -leafGroup :: Parser Group -leafGroup = do - leaf <- leafGroup' - postOps <- many postfixGroup - return $ foldl (\accum (op, opLhs) -> joinSrc accum opLhs $ CBin (WithSrc emptySrcPosCtx op) accum opLhs) leaf postOps +leafGroup :: Parser GroupW +leafGroup = leafGroup' >>= appendPostfixGroups where - - leafGroup' :: Parser Group - leafGroup' = withSrc do + leafGroup' :: Parser GroupW + leafGroup' = do next <- nextChar case next of - '_' -> underscore $> CHole - '(' -> (CIdentifier <$> symName) + '_' -> withSrcs $ CLeaf <$> (underscore >> pure CHole) + '(' -> toCLeaf CIdentifier <$> symName <|> cParens '[' -> cBrackets - '\"' -> CString <$> strLit - '\'' -> CChar <$> charLit + '\"' -> toCLeaf CString <$> strLit + '\'' -> toCLeaf CChar <$> charLit '%' -> do - name <- primName + WithSrc sid name <- primName case strToPrimName name of - Just prim -> CPrim prim <$> argList + Just prim -> WithSrcs sid [] <$> CPrim prim <$> argList Nothing -> fail $ "Unrecognized primitive: " ++ name - _ | isDigit next -> ( CNat <$> natLit - <|> CFloat <$> doubleLit) - '\\' -> cNullaryLam <|> cLam + _ | isDigit next -> ( toCLeaf CNat <$> natLit + <|> toCLeaf CFloat <$> doubleLit) + '\\' -> withSrcs (cNullaryLam <|> cLam) -- For exprs include for, rof, for_, rof_ - 'f' -> cFor <|> cIdentifier - 'd' -> cDo <|> cIdentifier - 'r' -> cFor <|> cIdentifier - 'c' -> cCase <|> cIdentifier - 'i' -> cIf <|> cIdentifier + 'f' -> (withSrcs cFor) <|> cIdentifier + 'd' -> (withSrcs cDo) <|> cIdentifier + 'r' -> (withSrcs cFor) <|> cIdentifier + 'c' -> (withSrcs cCase) <|> cIdentifier + 'i' -> (withSrcs cIf) <|> cIdentifier _ -> cIdentifier - postfixGroup :: Parser (Bin', Group) - postfixGroup = noGap >> - ((JuxtaposeNoSpace,) <$> withSrc cParens) - <|> ((JuxtaposeNoSpace,) <$> withSrc cBrackets) - <|> ((Dot,) <$> (try $ char '.' >> withSrc cFieldName)) + appendPostfixGroups :: GroupW -> Parser GroupW + appendPostfixGroups g = + (noGap >> appendPostfixGroup g >>= appendPostfixGroups) + <|> return g + + appendPostfixGroup :: GroupW -> Parser GroupW + appendPostfixGroup g = withSrcs $ + (CJuxtapose False g <$> cParens) + <|> (CJuxtapose False g <$> cBrackets) + <|> appendFieldAccess g + + appendFieldAccess :: GroupW -> Parser Group + appendFieldAccess g = try do + sid <- dot + field <- cFieldName + return $ CBin (WithSrc sid Dot) g field -cFieldName :: Parser Group' -cFieldName = cIdentifier <|> (CNat <$> natLit) +cFieldName :: Parser GroupW +cFieldName = cIdentifier <|> (toCLeaf CNat <$> natLit) -cIdentifier :: Parser Group' -cIdentifier = CIdentifier <$> anyName +cIdentifier :: Parser GroupW +cIdentifier = toCLeaf CIdentifier <$> anyName -cParens :: Parser Group' -cParens = CParens <$> parens (commaSep cParenGroup) +toCLeaf :: (a -> CLeaf) -> WithSrc a -> GroupW +toCLeaf toLeaf (WithSrc sid leaf) = WithSrcs sid [] $ CLeaf $ toLeaf leaf -cBrackets :: Parser Group' -cBrackets = CBrackets <$> brackets (commaSep cGroup) +cParens :: Parser GroupW +cParens = withSrcs $ CParens <$> bracketedGroup lParen rParen cParenGroup + +cBrackets :: Parser GroupW +cBrackets = withSrcs $ CBrackets <$> bracketedGroup lBracket rBracket cGroup -- A `PrecTable` is enough information to (i) remove or replace -- operators for special contexts, and (ii) build the input structure @@ -664,7 +582,7 @@ makeExprParser p tbl = Expr.makeExprParser p tbl' where withoutOp :: SourceName -> PrecTable a -> PrecTable a withoutOp op tbl = map (filter ((/= op) . fst)) tbl -ops :: PrecTable Group +ops :: PrecTable GroupW ops = [ [symOpL "!"] , [juxtaposition] @@ -685,7 +603,6 @@ ops = , [symOpN "==", symOpN "!="] , [symOpL "&&"] , [symOpL "||"] - , [unOpPre "..", unOpPre "..<", unOpPost "..", unOpPost "<.."] , [symOpR "=>"] , [arrow, symOpR "->>"] , [symOpL ">>>"] @@ -693,190 +610,131 @@ ops = , [symOpL "@"] , [symOpN "::"] , [symOpR "$"] - , [symOpL "|"] , [symOpN "+=", symOpN ":="] -- Associate right so the mistaken utterance foo : i:Fin 4 => (..i) -- groups as a bad pi type rather than a bad binder , [symOpR ":"] + , [symOpL "|"] , [symOpR ",>"] , [symOpR "&>"] , [withClausePostfix] , [symOpL "="] ] where other = ("other", anySymOp) - backquote = ("backquote", Expr.InfixL $ opWithSrc $ backquoteName >>= (return . binApp . EvalBinOp)) - juxtaposition = ("space", Expr.InfixL $ opWithSrc $ sc $> (binApp JuxtaposeWithSpace)) + backquote = ("backquote", Expr.InfixL backquoteOp) + juxtaposition = ("space", Expr.InfixL $ sc >> addSrcIdToBinOp (return $ CJuxtapose True)) + withClausePostfix = ("with", Expr.Postfix withClausePostfixOp) arrow = ("->", Expr.InfixR arrowOp) -opWithSrc :: Parser (SrcPos -> a -> a -> a) - -> Parser (a -> a -> a) -opWithSrc p = do - (f, pos) <- withPos p - return $ f pos -{-# INLINE opWithSrc #-} - -anySymOp :: Expr.Operator Parser Group -anySymOp = Expr.InfixL $ opWithSrc $ do - s <- label "infix operator" (mayBreak anySym) - return $ binApp $ interp_operator s - -infixSym :: SourceName -> Parser () -infixSym s = mayBreak $ sym $ T.pack s - -symOpN :: SourceName -> (SourceName, Expr.Operator Parser Group) -symOpN s = (s, Expr.InfixN $ symOp s) - -symOpL :: SourceName -> (SourceName, Expr.Operator Parser Group) -symOpL s = (s, Expr.InfixL $ symOp s) - -symOpR :: SourceName -> (SourceName, Expr.Operator Parser Group) -symOpR s = (s, Expr.InfixR $ symOp s) - -symOp :: SourceName -> Parser (Group -> Group -> Group) -symOp s = opWithSrc $ do - label "infix operator" (infixSym s) - return $ binApp $ interp_operator s - -arrowOp :: Parser (Group -> Group -> Group) -arrowOp = do - WithSrc src effs <- withSrc arrowOptEffs - return \lhs rhs -> WithSrc src $ CArrow lhs effs rhs - -unOpPre :: SourceName -> (SourceName, Expr.Operator Parser Group) -unOpPre s = (s, Expr.Prefix $ unOp CPrefix s) - -unOpPost :: SourceName -> (SourceName, Expr.Operator Parser Group) -unOpPost s = (s, Expr.Postfix $ unOp CPostfix s) - -unOp :: (SourceName -> Group -> Group') -> SourceName -> Parser (Group -> Group) -unOp f s = do - ((), pos) <- withPos $ sym $ fromString s - return \g@(WithSrc grpPos _) -> WithSrc (joinPos (fromPos pos) grpPos) $ f s g - -binApp :: Bin' -> SrcPos -> Group -> Group -> Group -binApp f pos x y = joinSrc3 f' x y $ CBin f' x y - where f' = WithSrc (fromPos pos) f - -withClausePostfix :: (SourceName, Expr.Operator Parser Group) -withClausePostfix = ("with", op) - where - op = Expr.Postfix do - rhs <- withClause - return \lhs -> WithSrc emptySrcPosCtx $ CWith lhs rhs -- TODO: source info - -withSrc :: Parser a -> Parser (WithSrc a) -withSrc p = do - (x, pos) <- withPos p - return $ WithSrc (fromPos pos) x - -joinSrc :: WithSrc a1 -> WithSrc a2 -> a3 -> WithSrc a3 -joinSrc (WithSrc p1 _) (WithSrc p2 _) x = WithSrc (joinPos p1 p2) x - -joinSrc3 :: WithSrc a1 -> WithSrc a2 -> WithSrc a3 -> a4 -> WithSrc a4 -joinSrc3 (WithSrc p1 _) (WithSrc p2 _) (WithSrc p3 _) x = - WithSrc (concatPos [p1, p2, p3]) x - -joinPos :: SrcPosCtx -> SrcPosCtx -> SrcPosCtx -joinPos (SrcPosCtx Nothing _) c@(SrcPosCtx _ _) = c -joinPos c@(SrcPosCtx _ _) (SrcPosCtx Nothing _) = c -joinPos (SrcPosCtx (Just (l, h)) spanId) (SrcPosCtx (Just (l', h')) _) = - SrcPosCtx (Just (min l l', max h h')) spanId - -concatPos :: [SrcPosCtx] -> SrcPosCtx -concatPos [] = error "concatPos: unexpected empty [SrcPosCtx]" -concatPos (pos:rest) = foldl joinPos pos rest - --- === primitive constructors and operators === - -strToPrimName :: String -> Maybe PrimName -strToPrimName s = M.lookup s primNames - -primNameToStr :: PrimName -> String -primNameToStr prim = case lookup prim $ map swap $ M.toList primNames of - Just s -> s - Nothing -> show prim - -showPrimName :: PrimName -> String -showPrimName prim = primNameToStr prim -{-# NOINLINE showPrimName #-} - -primNames :: M.Map String PrimName -primNames = M.fromList - [ ("ask" , UMAsk), ("mextend", UMExtend) - , ("get" , UMGet), ("put" , UMPut) - , ("while" , UWhile) - , ("linearize", ULinearize), ("linearTranspose", UTranspose) - , ("runReader", URunReader), ("runWriter" , URunWriter), ("runState", URunState) - , ("runIO" , URunIO ), ("catchException" , UCatchException) - , ("iadd" , binary IAdd), ("isub" , binary ISub) - , ("imul" , binary IMul), ("fdiv" , binary FDiv) - , ("fadd" , binary FAdd), ("fsub" , binary FSub) - , ("fmul" , binary FMul), ("idiv" , binary IDiv) - , ("irem" , binary IRem) - , ("fpow" , binary FPow) - , ("and" , binary BAnd), ("or" , binary BOr ) - , ("not" , unary BNot), ("xor" , binary BXor) - , ("shl" , binary BShL), ("shr" , binary BShR) - , ("ieq" , binary (ICmp Equal)), ("feq", binary (FCmp Equal)) - , ("igt" , binary (ICmp Greater)), ("fgt", binary (FCmp Greater)) - , ("ilt" , binary (ICmp Less)), ("flt", binary (FCmp Less)) - , ("fneg" , unary FNeg) - , ("exp" , unary Exp), ("exp2" , unary Exp2) - , ("log" , unary Log), ("log2" , unary Log2), ("log10" , unary Log10) - , ("sin" , unary Sin), ("cos" , unary Cos) - , ("tan" , unary Tan), ("sqrt" , unary Sqrt) - , ("floor", unary Floor), ("ceil" , unary Ceil), ("round", unary Round) - , ("log1p", unary Log1p), ("lgamma", unary LGamma) - , ("erf" , unary Erf), ("erfc" , unary Erfc) - , ("TyKind" , UPrimTC $ P.TypeKind) - , ("Float64" , baseTy $ Scalar Float64Type) - , ("Float32" , baseTy $ Scalar Float32Type) - , ("Int64" , baseTy $ Scalar Int64Type) - , ("Int32" , baseTy $ Scalar Int32Type) - , ("Word8" , baseTy $ Scalar Word8Type) - , ("Word32" , baseTy $ Scalar Word32Type) - , ("Word64" , baseTy $ Scalar Word64Type) - , ("Int32Ptr" , baseTy $ ptrTy $ Scalar Int32Type) - , ("Word8Ptr" , baseTy $ ptrTy $ Scalar Word8Type) - , ("Word32Ptr" , baseTy $ ptrTy $ Scalar Word32Type) - , ("Word64Ptr" , baseTy $ ptrTy $ Scalar Word64Type) - , ("Float32Ptr", baseTy $ ptrTy $ Scalar Float32Type) - , ("PtrPtr" , baseTy $ ptrTy $ ptrTy $ Scalar Word8Type) - , ("Nat" , UNat) - , ("Fin" , UFin) - , ("EffKind" , UEffectRowKind) - , ("NatCon" , UNatCon) - , ("Ref" , UPrimTC $ P.RefType) - , ("HeapType" , UPrimTC $ P.HeapType) - , ("indexRef" , UIndexRef) - , ("alloc" , memOp $ P.IOAlloc) - , ("free" , memOp $ P.IOFree) - , ("ptrOffset", memOp $ P.PtrOffset) - , ("ptrLoad" , memOp $ P.PtrLoad) - , ("ptrStore" , memOp $ P.PtrStore) - , ("throwError" , miscOp $ P.ThrowError) - , ("throwException", miscOp $ P.ThrowException) - , ("dataConTag" , miscOp $ P.SumTag) - , ("toEnum" , miscOp $ P.ToEnum) - , ("outputStream" , miscOp $ P.OutputStream) - , ("cast" , miscOp $ P.CastOp) - , ("bitcast" , miscOp $ P.BitcastOp) - , ("unsafeCoerce" , miscOp $ P.UnsafeCoerce) - , ("garbageVal" , miscOp $ P.GarbageVal) - , ("select" , miscOp $ P.Select) - , ("showAny" , miscOp $ P.ShowAny) - , ("showScalar" , miscOp $ P.ShowScalar) - , ("projNewtype" , UProjNewtype) - , ("applyMethod0" , UApplyMethod 0) - , ("applyMethod1" , UApplyMethod 1) - , ("applyMethod2" , UApplyMethod 2) - , ("explicitApply", UExplicitApply) - , ("monoLit", UMonoLiteral) - ] - where - binary op = UBinOp op - baseTy b = UBaseType b - memOp op = UMemOp op - unary op = UUnOp op - ptrTy ty = PtrType (CPU, ty) - miscOp op = UMiscOp op +addSrcIdToBinOp :: Parser (GroupW -> GroupW -> Group) -> Parser (GroupW -> GroupW -> GroupW) +addSrcIdToBinOp op = do + f <- op + sid <- freshSrcId + return \x y -> WithSrcs sid [] $ f x y +{-# INLINE addSrcIdToBinOp #-} + +addSrcIdToUnOp :: Parser (GroupW -> Group) -> Parser (GroupW -> GroupW) +addSrcIdToUnOp op = do + f <- op + sid <- freshSrcId + return \x -> WithSrcs sid [] $ f x +{-# INLINE addSrcIdToUnOp #-} + +backquoteOp :: Parser (GroupW -> GroupW -> GroupW) +backquoteOp = binApp do + WithSrc sid fname <- backquoteName + return $ WithSrc sid $ EvalBinOp fname + +anySymOp :: Expr.Operator Parser GroupW +anySymOp = Expr.InfixL $ binApp do + WithSrc sid s <- label "infix operator" (mayBreak anySym) + return $ WithSrc sid $ interpOperator s + +infixSym :: String -> Parser SrcId +infixSym s = mayBreak $ symWithId $ T.pack s + +symOpN :: String -> (SourceName, Expr.Operator Parser GroupW) +symOpN s = (fromString s, Expr.InfixN $ symOp s) + +symOpL :: String -> (SourceName, Expr.Operator Parser GroupW) +symOpL s = (fromString s, Expr.InfixL $ symOp s) + +symOpR :: String -> (SourceName, Expr.Operator Parser GroupW) +symOpR s = (fromString s, Expr.InfixR $ symOp s) + +symOp :: String -> Parser (GroupW -> GroupW -> GroupW) +symOp s = binApp do + sid <- label "infix operator" (infixSym s) + return $ WithSrc sid $ interpOperator s + +arrowOp :: Parser (GroupW -> GroupW -> GroupW) +arrowOp = addSrcIdToBinOp do + sym "->" + optEffs <- optional cEffs + return \lhs rhs -> CArrow lhs optEffs rhs + +unOpPre :: String -> (SourceName, Expr.Operator Parser GroupW) +unOpPre s = (fromString s, Expr.Prefix $ prefixOp s) + +prefixOp :: String -> Parser (GroupW -> GroupW) +prefixOp s = addSrcIdToUnOp do + symId <- symWithId (fromString s) + return $ CPrefix (WithSrc symId $ fromString s) + +binApp :: Parser BinW -> Parser (GroupW -> GroupW -> GroupW) +binApp f = addSrcIdToBinOp $ CBin <$> f + +withClausePostfixOp :: Parser (GroupW -> GroupW) +withClausePostfixOp = addSrcIdToUnOp do + rhs <- withClause + return \lhs -> CWith lhs rhs + +withSrcs :: Parser a -> Parser (WithSrcs a) +withSrcs p = do + sid <- freshSrcId + (sids, result) <- collectAtomicLexemeIds p + return $ WithSrcs sid sids result + +-- === notes === + +-- note [if-syntax] +-- We support the following syntaxes for `if`: +-- - 1-armed then-newline +-- if predicate +-- then consequent +-- if predicate +-- then +-- consequent1 +-- consequent2 +-- - 2-armed then-newline else-newline +-- if predicate +-- then consequent +-- else alternate +-- and the three other versions where the consequent or the +-- alternate are themselves blocks +-- - 1-armed then-inline +-- if predicate then consequent +-- if predicate then +-- consequent1 +-- consequent2 +-- - 2-armed then-inline else-inline +-- if predicate then consequent else alternate +-- if predicate then consequent else +-- alternate1 +-- alternate2 +-- - Notably, an imagined 2-armed then-inline else-newline +-- if predicate then +-- consequent1 +-- consequent2 +-- else alternate +-- is not an option, because the indentation lines up badly. To wit, +-- one would want the `else` to be indented relative to the `if`, +-- but outdented relative to the consequent block, and if the `then` is +-- inline there is no indentation level that does that. +-- - Last candiate is +-- if predicate +-- then consequent else alternate +-- if predicate +-- then consequent else +-- alternate1 +-- alternate2 diff --git a/src/lib/Core.hs b/src/lib/Core.hs index 8bdac679e..e420b50a7 100644 --- a/src/lib/Core.hs +++ b/src/lib/Core.hs @@ -37,6 +37,7 @@ import Err import IRVariants import Types.Core +import Types.Top import Types.Imp import Types.Primitives import Types.Source @@ -79,10 +80,9 @@ type EnvExtender2 (m::MonadKind2) = forall (n::S). EnvExtender (m n) newtype EnvReaderT (m::MonadKind) (n::S) (a:: *) = EnvReaderT {runEnvReaderT' :: ReaderT (DistinctEvidence n, Env n) m a } deriving ( Functor, Applicative, Monad, MonadFail - , MonadWriter w, Fallible, Searcher, Alternative) + , MonadWriter w, Fallible, Alternative) type EnvReaderM = EnvReaderT Identity -type FallibleEnvReaderM = EnvReaderT FallibleM runEnvReaderM :: Distinct n => Env n -> EnvReaderM n a -> a runEnvReaderM bindings m = runIdentity $ runEnvReaderT bindings m @@ -132,6 +132,11 @@ instance MonadIO m => MonadIO (EnvReaderT m n) where deriving instance (Monad m, MonadState s m) => MonadState s (EnvReaderT m o) +instance (Monad m, Catchable m) => Catchable (EnvReaderT m o) where + catchErr (EnvReaderT (ReaderT m)) f = EnvReaderT $ ReaderT \env -> + m env `catchErr` \err -> runReaderT (runEnvReaderT' $ f err) env + {-# INLINE catchErr #-} + -- === Instances for Name monads === instance (SinkableE e, EnvReader m) @@ -389,12 +394,6 @@ withFreshBinders (binding:rest) cont = do cont (Nest b bs) (sink (binderName b) : vs) -getInstanceDicts :: EnvReader m => ClassName n -> m n [InstanceName n] -getInstanceDicts name = do - env <- withEnv moduleEnv - return $ M.findWithDefault [] name $ instanceDicts $ envSynthCandidates env -{-# INLINE getInstanceDicts #-} - -- These `fromNary` functions traverse a chain of unary structures (LamExpr, -- TabLamExpr, CorePiType, respectively) up to the given maxDepth, and return the -- discovered binders packed as the nary structure (NaryLamExpr or PiType), @@ -407,9 +406,10 @@ getInstanceDicts name = do -- structure. Excess binders, if any, are still left in the unary structures. liftLamExpr :: (IRRep r, EnvReader m) - => (forall l m2. EnvReader m2 => Block r l -> m2 l (Block r l)) - -> TopLam r n -> m n (TopLam r n) -liftLamExpr f (TopLam d ty (LamExpr bs body)) = liftM (TopLam d ty) $ liftEnvReaderM $ + => TopLam r n + -> (forall l m2. EnvReader m2 => Expr r l -> m2 l (Expr r l)) + -> m n (TopLam r n) +liftLamExpr (TopLam d ty (LamExpr bs body)) f = liftM (TopLam d ty) $ liftEnvReaderM $ refreshAbs (Abs bs body) \bs' body' -> LamExpr bs' <$> f body' fromNaryForExpr :: IRRep r => Int -> Expr r n -> Maybe (Int, LamExpr r n) @@ -419,39 +419,11 @@ fromNaryForExpr maxDepth = \case extend <|> (Just $ (1, LamExpr (Nest b Empty) body)) where extend = do - expr <- exprBlock body guard $ maxDepth > 1 - (d, LamExpr bs body2) <- fromNaryForExpr (maxDepth - 1) expr + (d, LamExpr bs body2) <- fromNaryForExpr (maxDepth - 1) body return (d + 1, LamExpr (Nest b bs) body2) _ -> Nothing -mkConsListTy :: [Type r n] -> Type r n -mkConsListTy = foldr PairTy UnitTy - -mkConsList :: [Atom r n] -> Atom r n -mkConsList = foldr PairVal UnitVal - -fromConsListTy :: (IRRep r, Fallible m) => Type r n -> m [Type r n] -fromConsListTy ty = case ty of - UnitTy -> return [] - PairTy t rest -> (t:) <$> fromConsListTy rest - _ -> throw CompilerErr $ "Not a pair or unit: " ++ show ty - --- ((...((ans & x{n}) & x{n-1})... & x2) & x1) -> (ans, [x1, ..., x{n}]) -fromLeftLeaningConsListTy :: (IRRep r, Fallible m) => Int -> Type r n -> m (Type r n, [Type r n]) -fromLeftLeaningConsListTy depth initTy = go depth initTy [] - where - go 0 ty xs = return (ty, reverse xs) - go remDepth ty xs = case ty of - PairTy lt rt -> go (remDepth - 1) lt (rt : xs) - _ -> throw CompilerErr $ "Not a pair: " ++ show xs - -fromConsList :: (IRRep r, Fallible m) => Atom r n -> m [Atom r n] -fromConsList xs = case xs of - UnitVal -> return [] - PairVal x rest -> (x:) <$> fromConsList rest - _ -> throw CompilerErr $ "Not a pair or unit: " ++ show xs - type BundleDesc = Int -- length bundleFold :: a -> (a -> a -> a) -> [a] -> (a, BundleDesc) @@ -462,16 +434,10 @@ bundleFold emptyVal pair els = case els of where (tb, td) = bundleFold emptyVal pair t mkBundleTy :: [Type r n] -> (Type r n, BundleDesc) -mkBundleTy = bundleFold UnitTy PairTy +mkBundleTy = bundleFold UnitTy (\x y -> TyCon (ProdType [x, y])) mkBundle :: [Atom r n] -> (Atom r n, BundleDesc) -mkBundle = bundleFold UnitVal PairVal - -trySelectBranch :: IRRep r => Atom r n -> Maybe (Int, Atom r n) -trySelectBranch e = case e of - SumVal _ i value -> Just (i, value) - NewtypeCon con e' | isSumCon con -> trySelectBranch e' - _ -> Nothing +mkBundle = bundleFold UnitVal (\x y -> Con (ProdCon [x, y])) freeAtomVarsList :: forall r e n. (IRRep r, HoistableE e) => e n -> [Name (AtomNameC r) n] freeAtomVarsList = freeVarsList diff --git a/src/lib/Err.hs b/src/lib/Err.hs index b575c048e..7405cfbf9 100644 --- a/src/lib/Err.hs +++ b/src/lib/Err.hs @@ -4,127 +4,322 @@ -- license that can be found in the LICENSE file or at -- https://developers.google.com/open-source/licenses/bsd -module Err (Err (..), Errs (..), ErrType (..), Except (..), - ErrCtx (..), SrcTextCtx, - Fallible (..), Catchable (..), catchErrExcept, - FallibleM (..), HardFailM (..), CtxReader (..), - runFallibleM, runHardFail, throw, throwErr, - addContext, addSrcContext, addSrcTextContext, - catchIOExcept, liftExcept, liftExceptAlt, - assertEq, ignoreExcept, - pprint, docAsStr, getCurrentCallStack, printCurrentCallStack, - FallibleApplicativeWrapper, traverseMergingErrs, - SearcherM (..), Searcher (..), runSearcherM) where +{-# LANGUAGE UndecidableInstances #-} + +module Err ( + Err (..), Except (..), ToErr (..), PrintableErr (..), + ParseErr (..), SyntaxErr (..), NameErr (..), TypeErr (..), MiscErr (..), + Fallible (..), Catchable (..), catchErrExcept, + HardFailM (..), runHardFail, throw, + catchIOExcept, liftExcept, liftExceptAlt, + ignoreExcept, getCurrentCallStack, printCurrentCallStack, + ExceptT (..), rootSrcId, SrcId (..), assertEq, throwInternal, + InferenceArgDesc, InfVarDesc (..), HasSrcId (..)) where import Control.Exception hiding (throw) import Control.Applicative import Control.Monad -import Control.Monad.Trans.Maybe import Control.Monad.Identity import Control.Monad.Writer.Strict import Control.Monad.State.Strict import Control.Monad.Reader +import Data.Aeson (ToJSON, ToJSONKey) import Data.Coerce +import Data.Hashable +import Data.List (sort) import Data.Foldable (fold) -import Data.Text (Text) -import Data.Text qualified as T -import Data.Text.Prettyprint.Doc.Render.Text -import Data.Text.Prettyprint.Doc -import GHC.Generics (Generic (..)) +import Data.Store (Store (..)) import GHC.Stack -import System.Environment -import System.IO.Unsafe -import SourceInfo - --- === core API === - -data Err = Err ErrType ErrCtx String deriving (Show, Eq) -newtype Errs = Errs [Err] deriving (Eq, Semigroup, Monoid) - -data ErrType = NoErr - | ParseErr - | SyntaxErr - | TypeErr - | KindErr - | LinErr - | VarDefErr - | UnboundVarErr - | AmbiguousVarErr - | RepeatedVarErr - | RepeatedPatVarErr - | InvalidPatternErr - | CompilerErr - | IRVariantErr - | NotImplementedErr - | DataIOErr - | MiscErr - | RuntimeErr - | ZipErr - | EscapedNameErr - | ModuleImportErr - | MonadFailErr - deriving (Show, Eq) - -type SrcTextCtx = Maybe (Int, Text) -- Int is the offset in the source file -data ErrCtx = ErrCtx - { srcTextCtx :: SrcTextCtx - , srcPosCtx :: SrcPosCtx - , messageCtx :: [String] - , stackCtx :: Maybe [String] } - deriving (Show, Eq, Generic) +import GHC.Generics + +import PPrint + +-- === source info === + +-- XXX: 0 is reserved for the root The IDs are generated from left to right in +-- parsing order, so IDs for lexemes are guaranteed to be sorted correctly. +newtype SrcId = SrcId Int deriving (Show, Eq, Ord, Generic) + +rootSrcId :: SrcId +rootSrcId = SrcId 0 + +class HasSrcId a where + getSrcId :: a -> SrcId + +-- === core errro type === + +data Err = + SearchFailure String -- used as the identity for `Alternative` instances and for MonadFail. + | InternalErr String + | ParseErr ParseErr + | SyntaxErr SrcId SyntaxErr + | NameErr SrcId NameErr + | TypeErr SrcId TypeErr + | RuntimeErr + | MiscErr MiscErr + deriving (Show, Eq) + +type MsgStr = String +type VarStr = String +type TypeStr = String + +data ParseErr = + MiscParseErr MsgStr + deriving (Show, Eq) + +data SyntaxErr = + MiscSyntaxErr MsgStr + | TopLevelArrowBinder + | CantConstrainAnonBinders + | UnexpectedBinder + | OnlyUnaryWithoutParens + | IllegalPattern + | UnexpectedConstraint + | ExpectedIdentifier String + | UnexpectedEffectForm + | UnexpectedMethodDef + | BlockWithoutFinalExpr + | UnexpectedGivenClause + | ArgsShouldHaveParens + | BadEqualSign + | BadColon + | ExpectedAnnBinder + | BadField + | BadPrefix VarStr + deriving (Show, Eq) + +data NameErr = + MiscNameErr MsgStr + | UnboundVarErr VarStr -- name of var + | EscapedNameErr [VarStr] -- names + | RepeatedPatVarErr VarStr + | RepeatedVarErr VarStr + | NotAnOrdinaryVar VarStr + | NotADataCon VarStr + | NotAClassName VarStr + | NotAMethodName VarStr + | AmbiguousVarErr VarStr [String] + | VarDefErr VarStr + deriving (Show, Eq) + +data TypeErr = + MiscTypeErr MsgStr + | CantSynthDict TypeStr + | CantSynthInfVars TypeStr + | NotASynthType TypeStr + | CantUnifySkolem + | OccursCheckFailure VarStr TypeStr + | UnificationFailure TypeStr TypeStr [VarStr] -- expected, actual, inference vars + | DisallowedEffects String String -- allowed, actual + | InferEmptyTable + | ArityErr Int Int -- expected, actual + | PatternArityErr Int Int -- expected, actual + | SumTypeCantFail + | PatTypeErr String String -- expected type constructor (from pattern), actual type (from rhs) + | EliminationErr String String -- expected type constructor, actual type + | IllFormedCasePattern + | NotAMethod VarStr VarStr + | DuplicateMethod VarStr + | MissingMethod VarStr + | WrongArrowErr String String + | AnnotationRequired + | NotAUnaryConstraint TypeStr + | InterfacesNoImplicitParams + | RepeatedOptionalArgs [VarStr] + | UnrecognizedOptionalArgs [VarStr] [VarStr] + | NoFields TypeStr + | TypeMismatch TypeStr TypeStr -- TODO: should we merege this with UnificationFailure + | InferHoleErr + | InferDepPairErr + | InferEmptyCaseEff + | UnexpectedTerm String TypeStr + | CantFindField VarStr TypeStr [VarStr] -- field name, field type, known fields + | TupleLengthMismatch Int Int + | CantReduceType TypeStr + | CantReduceDict + | CantReduceDependentArg + | AmbiguousInferenceVar VarStr TypeStr InfVarDesc + | FFIResultTyErr TypeStr + | FFIArgTyNotScalar TypeStr + deriving (Show, Eq) + +data MiscErr = + MiscMiscErr MsgStr + | ModuleImportErr VarStr + | CantFindModuleSource VarStr + deriving (Show, Eq) + +-- name of function, name of arg +type InferenceArgDesc = (String, String) +data InfVarDesc = + ImplicitArgInfVar InferenceArgDesc + | AnnotationInfVar String -- name of binder + | TypeInstantiationInfVar String -- name of type + | MiscInfVar + deriving (Show, Generic, Eq, Ord) + +-- === ToErr class === + +class ToErr a where + toErr :: SrcId -> a -> Err + +instance ToErr SyntaxErr where toErr = SyntaxErr +instance ToErr NameErr where toErr = NameErr +instance ToErr TypeErr where toErr = TypeErr + +-- === Error messages === + +class PrintableErr a where + printErr :: a -> String + +instance PrintableErr Err where + printErr = \case + SearchFailure s -> "Internal search failure: " ++ s + InternalErr s -> "Internal compiler error: " ++ s ++ "\n" ++ + "Please report this at github.com/google-research/dex-lang/issues\n" + ParseErr e -> "Parse error: " ++ printErr e + SyntaxErr _ e -> "Syntax error: " ++ printErr e + NameErr _ e -> "Name error: " ++ printErr e + TypeErr _ e -> "Type error: " ++ printErr e + MiscErr e -> "Error: " ++ printErr e + RuntimeErr -> "Runtime error" + +instance PrintableErr ParseErr where + printErr = \case + MiscParseErr s -> s + +instance PrintableErr SyntaxErr where + printErr = \case + MiscSyntaxErr s -> s + TopLevelArrowBinder -> + "Arrow binder syntax <- not permitted at the top level, because the binding would have unbounded scope." + CantConstrainAnonBinders -> "can't constrain anonymous binders" + UnexpectedBinder -> "binder must be an identifier or `_`" + OnlyUnaryWithoutParens ->"only unary constructors can form patterns without parens" + IllegalPattern -> "illegal pattern" + UnexpectedConstraint -> "unexpected constraint" + ExpectedIdentifier ctx -> "expected " ++ ctx ++ " to be an identifier" + UnexpectedEffectForm -> + "unexpected effect form; expected one of `Read h`, `Accum h`, `State h`, `Except`, `IO`, " + ++ "or the name of a user-defined effect." + UnexpectedMethodDef -> "unexpected method definition. Expected `def` or `x = ...`." + BlockWithoutFinalExpr -> "block must end in expression" + UnexpectedGivenClause -> "unexpected `given` clause" + ArgsShouldHaveParens -> "argument types should be in parentheses" + BadEqualSign -> "equal sign must be used as a separator for labels or binders, not a standalone operator" + BadColon -> + "colon separates binders from their type annotations, is not a standalone operator.\n" + ++ " If you are trying to write a dependent type, use parens: (i:Fin 4) => (..i)" + ExpectedAnnBinder -> "expected an annotated binder" + BadField -> "field must be a name or an integer" + BadPrefix name -> "prefix (" ++ name ++ ") not legal as a bare expression" + +instance PrintableErr NameErr where + printErr = \case + MiscNameErr s -> s + UnboundVarErr v -> "variable not in scope: " ++ v + EscapedNameErr vs -> "leaked local variables: " ++ unwords vs + RepeatedPatVarErr v -> "variable already defined within pattern: " ++ v + RepeatedVarErr v -> "variable already defined : " ++ v + NotAnOrdinaryVar v -> "not an ordinary variable: " ++ v + NotADataCon v -> "not a data constructor: " ++ v + NotAClassName v -> "not a class name: " ++ v + NotAMethodName v -> "not a method name: " ++ v + -- we sort the lines to make the result a bit more deterministic for quine tests + AmbiguousVarErr v defs -> + "ambiguous occurrence: " ++ v ++ " is defined:\n" + ++ unlines (sort defs) + -- TODO: we see this message a lot. We should improve it by including more information. + -- Ideally we'd provide a link to where the original error happened." + VarDefErr v -> "error in (earlier) definition of variable: " ++ v + +instance PrintableErr TypeErr where + printErr = \case + MiscTypeErr s -> s + FFIResultTyErr t -> "FFI result type should be scalar or pair. Got: " ++ t + FFIArgTyNotScalar t -> "FFI function arguments should be scalar. Got: " ++ t + CantSynthDict t -> "can't synthesize a class dictionary for: " ++ t + CantSynthInfVars t -> "can't synthesize a class dictionary for a type with inference vars: " ++ t + NotASynthType t -> "can't synthesize terms of type: " ++ t + CantUnifySkolem -> "can't unify with skolem vars" + OccursCheckFailure v t -> "occurs check failure: " ++ v ++ " occurs in " ++ t + DisallowedEffects r1 r2 -> "\nAllowed: " ++ pprint r1 ++ + "\nRequested: " ++ pprint r2 + UnificationFailure t1 t2 vs -> "\nExpected: " ++ t1 + ++ "\nActual: " ++ t2 ++ case vs of + [] -> "" + _ -> "\n(Solving for: " ++ unwords vs ++ ")" + InferEmptyTable -> "can't infer type of empty table" + ArityErr n1 n2 -> "wrong number of positional arguments provided. Expected " ++ show n1 ++ " but got " ++ show n2 + PatternArityErr n1 n2 -> "unexpected number of pattern binders. Expected " ++ show n1 ++ " but got " ++ show n2 + SumTypeCantFail -> "sum type constructor in can't-fail pattern" + PatTypeErr patTy rhsTy -> "pattern is for a " ++ patTy ++ "but we're matching against a " ++ rhsTy + EliminationErr expected ty -> "expected a " ++ expected ++ ". Got: " ++ ty + IllFormedCasePattern -> "case patterns must start with a data constructor or variant pattern" + NotAMethod method className -> "unexpected method: " ++ method ++ " is not a method of " ++ className + DuplicateMethod method -> "duplicate method: " ++ method + MissingMethod method -> "missing method: " ++ method + WrongArrowErr expected actual -> "wrong arrow. Expected " ++ expected ++ " got " ++ actual + AnnotationRequired -> "type annotation or constraint required" + NotAUnaryConstraint ty -> "constraint should be a unary function. Got: " ++ ty + InterfacesNoImplicitParams -> "interfaces can't have implicit parameters" + RepeatedOptionalArgs vs -> "repeated names offered:" ++ unwords vs + UnrecognizedOptionalArgs vs accepted -> "unrecognized named arguments: " ++ unwords vs + ++ ". Should be one of: " ++ pprint accepted + NoFields ty -> "can't get fields for type " ++ pprint ty + TypeMismatch expected actual -> "\nExpected: " ++ expected ++ + "\nActual: " ++ actual + InferHoleErr -> "can't infer value of hole" + InferDepPairErr -> "can't infer the type of a dependent pair; please annotate its type" + InferEmptyCaseEff -> "can't infer empty case expressions" + UnexpectedTerm term ty -> "unexpected " ++ term ++ ". Expected: " ++ ty + CantFindField field fieldTy knownFields -> + "can't resolve field " ++ field ++ " of type " ++ fieldTy ++ + "\nKnown fields are: " ++ unwords knownFields + TupleLengthMismatch req actual -> do + "tuple length mismatch. Expected: " ++ show req ++ " but got " ++ show actual + CantReduceType ty -> "Can't reduce type expression: " ++ ty + CantReduceDict -> "Can't reduce dict" + CantReduceDependentArg -> + "dependent functions can only be applied to fully evaluated expressions. " ++ + "Bind the argument to a name before you apply the function." + AmbiguousInferenceVar infVar ty desc -> case desc of + AnnotationInfVar v -> + "couldn't infer type of unannotated binder " <> v + ImplicitArgInfVar (f, argName) -> + "couldn't infer implicit argument `" <> argName <> "` of " <> f + TypeInstantiationInfVar t -> + "couldn't infer instantiation of type " <> t + MiscInfVar -> + "ambiguous type variable: " ++ infVar ++ ": " ++ ty + +instance PrintableErr MiscErr where + printErr = \case + MiscMiscErr s -> s + ModuleImportErr v -> "couldn't import " ++ v + CantFindModuleSource v -> + "couldn't find a source file for module " ++ v ++ + "\nHint: Consider extending --lib-path" + +-- === monads and helpers === class MonadFail m => Fallible m where - throwErrs :: Errs -> m a - addErrCtx :: ErrCtx -> m a -> m a + throwErr :: Err -> m a class Fallible m => Catchable m where - catchErr :: m a -> (Errs -> m a) -> m a + catchErr :: m a -> (Err -> m a) -> m a catchErrExcept :: Catchable m => m a -> m (Except a) catchErrExcept m = catchErr (Success <$> m) (\e -> return $ Failure e) --- We have this in its own class because IO and `Except` can't implement it --- (but FallibleM can) -class Fallible m => CtxReader m where - getErrCtx :: m ErrCtx - --- We have this in its own class because StateT can't implement it --- (but FallibleM, Except and IO all can) -class Fallible m => FallibleApplicative m where - mergeErrs :: m a -> m b -> m (a, b) - -newtype FallibleM a = - FallibleM { fromFallibleM :: ReaderT ErrCtx Except a } - deriving (Functor, Applicative, Monad) - -instance Fallible FallibleM where - throwErrs (Errs errs) = FallibleM $ ReaderT \ambientCtx -> - throwErrs $ Errs [Err errTy (ambientCtx <> ctx) s | Err errTy ctx s <- errs] - {-# INLINE throwErrs #-} - addErrCtx ctx (FallibleM m) = FallibleM $ local (<> ctx) m - {-# INLINE addErrCtx #-} - -instance Catchable FallibleM where - FallibleM m `catchErr` handler = FallibleM $ ReaderT \ctx -> - case runReaderT m ctx of - Failure errs -> runReaderT (fromFallibleM $ handler errs) ctx - Success ans -> return ans - -instance FallibleApplicative FallibleM where - mergeErrs (FallibleM (ReaderT f1)) (FallibleM (ReaderT f2)) = - FallibleM $ ReaderT \ctx -> mergeErrs (f1 ctx) (f2 ctx) - -instance CtxReader FallibleM where - getErrCtx = FallibleM ask - {-# INLINE getErrCtx #-} +catchSearchFailure :: Catchable m => m a -> m (Maybe a) +catchSearchFailure m = (Just <$> m) `catchErr` \case + SearchFailure _ -> return Nothing + err -> throwErr err instance Fallible IO where - throwErrs errs = throwIO errs - {-# INLINE throwErrs #-} - addErrCtx ctx m = do - result <- catchIOExcept m - liftExcept $ addErrCtx ctx result - {-# INLINE addErrCtx #-} + throwErr errs = throwIO errs + {-# INLINE throwErr #-} instance Catchable IO where catchErr cont handler = @@ -132,23 +327,68 @@ instance Catchable IO where Success result -> return result Failure errs -> handler errs -instance FallibleApplicative IO where - mergeErrs m1 m2 = do - result1 <- catchIOExcept m1 - result2 <- catchIOExcept m2 - liftExcept $ mergeErrs result1 result2 +-- === ExceptT type === + +newtype ExceptT m a = ExceptT { runExceptT :: m (Except a) } + +instance Monad m => Functor (ExceptT m) where + fmap = liftM + {-# INLINE fmap #-} + +instance Monad m => Applicative (ExceptT m) where + pure = return + {-# INLINE pure #-} + liftA2 = liftM2 + {-# INLINE liftA2 #-} -runFallibleM :: FallibleM a -> Except a -runFallibleM m = runReaderT (fromFallibleM m) mempty -{-# INLINE runFallibleM #-} +instance Monad m => Monad (ExceptT m) where + return x = ExceptT $ return (Success x) + {-# INLINE return #-} + m >>= f = ExceptT $ runExceptT m >>= \case + Failure errs -> return $ Failure errs + Success x -> runExceptT $ f x + {-# INLINE (>>=) #-} + +instance Monad m => MonadFail (ExceptT m) where + fail s = ExceptT $ return $ Failure $ SearchFailure s + {-# INLINE fail #-} + +instance Monad m => Fallible (ExceptT m) where + throwErr errs = ExceptT $ return $ Failure errs + {-# INLINE throwErr #-} + +instance Monad m => Alternative (ExceptT m) where + empty = throwErr $ SearchFailure "" + {-# INLINE empty #-} + m1 <|> m2 = do + catchSearchFailure m1 >>= \case + Nothing -> m2 + Just x -> return x + {-# INLINE (<|>) #-} + +instance Monad m => Catchable (ExceptT m) where + m `catchErr` f = ExceptT $ runExceptT m >>= \case + Failure errs -> runExceptT $ f errs + Success x -> return $ Success x + {-# INLINE catchErr #-} + +instance MonadState s m => MonadState s (ExceptT m) where + get = lift get + {-# INLINE get #-} + put x = lift $ put x + {-# INLINE put #-} + +instance MonadTrans ExceptT where + lift m = ExceptT $ Success <$> m + {-# INLINE lift #-} -- === Except type === --- Except is isomorphic to `Either Errs` but having a distinct type makes it +-- Except is isomorphic to `Either Err` but having a distinct type makes it -- easier to debug type errors. data Except a = - Failure Errs + Failure Err | Success a deriving (Show, Eq) @@ -169,22 +409,19 @@ instance Monad Except where Success x >>= f = f x {-# INLINE (>>=) #-} --- === FallibleApplicativeWrapper === - --- Wraps a Fallible monad, presenting an applicative interface that sequences --- actions using the error-concatenating `mergeErrs` instead of the default --- abort-on-failure sequencing. - -newtype FallibleApplicativeWrapper m a = - FallibleApplicativeWrapper { fromFallibleApplicativeWrapper :: m a } - deriving (Functor) +instance Alternative Except where + empty = throwErr $ SearchFailure "" + {-# INLINE empty #-} + m1 <|> m2 = do + catchSearchFailure m1 >>= \case + Nothing -> m2 + Just x -> return x + {-# INLINE (<|>) #-} -instance FallibleApplicative m => Applicative (FallibleApplicativeWrapper m) where - pure x = FallibleApplicativeWrapper $ pure x - {-# INLINE pure #-} - liftA2 f (FallibleApplicativeWrapper m1) (FallibleApplicativeWrapper m2) = - FallibleApplicativeWrapper $ fmap (uncurry f) (mergeErrs m1 m2) - {-# INLINE liftA2 #-} +instance Catchable Except where + Success ans `catchErr` _ = Success ans + Failure errs `catchErr` f = f errs + {-# INLINE catchErr #-} -- === HardFail === @@ -222,33 +459,15 @@ instance MonadFail HardFailM where {-# INLINE fail #-} instance Fallible HardFailM where - throwErrs errs = error $ pprint errs - {-# INLINE throwErrs #-} - addErrCtx _ cont = cont - {-# INLINE addErrCtx #-} - -instance FallibleApplicative HardFailM where - mergeErrs cont1 cont2 = (,) <$> cont1 <*> cont2 + throwErr errs = error $ pprint errs + {-# INLINE throwErr #-} -- === convenience layer === -throw :: Fallible m => ErrType -> String -> m a -throw errTy s = throwErrs $ Errs [addCompilerStackCtx $ Err errTy mempty s] +throw :: (ToErr e, Fallible m) => SrcId -> e -> m a +throw sid e = throwErr $ toErr sid e {-# INLINE throw #-} -throwErr :: Fallible m => Err -> m a -throwErr err = throwErrs $ Errs [addCompilerStackCtx err] -{-# INLINE throwErr #-} - -addCompilerStackCtx :: Err -> Err -addCompilerStackCtx (Err ty ctx msg) = Err ty ctx{stackCtx = compilerStack} msg - where -#ifdef DEX_DEBUG - compilerStack = getCurrentCallStack () -#else - compilerStack = stackCtx ctx -#endif - getCurrentCallStack :: () -> Maybe [String] getCurrentCallStack () = #ifdef DEX_DEBUG @@ -264,31 +483,19 @@ printCurrentCallStack :: Maybe [String] -> String printCurrentCallStack Nothing = "" printCurrentCallStack (Just frames) = fold frames -addContext :: Fallible m => String -> m a -> m a -addContext s m = addErrCtx (mempty {messageCtx = [s]}) m -{-# INLINE addContext #-} - -addSrcContext :: Fallible m => SrcPosCtx -> m a -> m a -addSrcContext ctx m = addErrCtx (mempty {srcPosCtx = ctx}) m -{-# INLINE addSrcContext #-} - -addSrcTextContext :: Fallible m => Int -> Text -> m a -> m a -addSrcTextContext offset text m = - addErrCtx (mempty {srcTextCtx = Just (offset, text)}) m - catchIOExcept :: MonadIO m => IO a -> m (Except a) catchIOExcept m = liftIO $ (liftM Success m) `catches` - [ Handler \(e::Errs) -> return $ Failure e - , Handler \(e::IOError) -> return $ Failure $ Errs [Err DataIOErr mempty $ show e] + [ Handler \(e::Err) -> return $ Failure e + , Handler \(e::IOError) -> return $ Failure $ MiscErr $ MiscMiscErr $ show e -- Propagate asynchronous exceptions like ThreadKilled; they are -- part of normal operation (of the live evaluation modes), not -- compiler bugs. , Handler \(e::AsyncException) -> liftIO $ throwIO e - , Handler \(e::SomeException) -> return $ Failure $ Errs [Err CompilerErr mempty $ show e] + , Handler \(e::SomeException) -> return $ Failure $ InternalErr $ show e ] liftExcept :: Fallible m => Except a -> m a -liftExcept (Failure errs) = throwErrs errs +liftExcept (Failure errs) = throwErr errs liftExcept (Success ans) = return ans {-# INLINE liftExcept #-} @@ -305,279 +512,58 @@ ignoreExcept (Success x) = x assertEq :: (HasCallStack, Fallible m, Show a, Pretty a, Eq a) => a -> a -> String -> m () assertEq x y s = if x == y then return () - else throw CompilerErr msg + else throwInternal msg where msg = "assertion failure (" ++ s ++ "):\n" - ++ pprint x ++ " != " ++ pprint y ++ "\n\n" - ++ prettyCallStack callStack ++ "\n" - --- === search monad === - -infix 0 -class (Monad m, Alternative m) => Searcher m where - -- Runs the second computation when the first yields an empty set of results. - -- This is just `<|>` for greedy searchers like `Maybe`, but in other cases, - -- like the list monad, it matters that the second computation isn't run if - -- the first succeeds. - () :: m a -> m a -> m a - --- Adds an extra error case to `FallibleM` so we can give it an Alternative --- instance with an identity element. -newtype SearcherM a = SearcherM { runSearcherM' :: MaybeT FallibleM a } - deriving (Functor, Applicative, Monad) - -runSearcherM :: SearcherM a -> Except (Maybe a) -runSearcherM m = runFallibleM $ runMaybeT (runSearcherM' m) -{-# INLINE runSearcherM #-} - -instance MonadFail SearcherM where - fail _ = SearcherM $ MaybeT $ return Nothing - {-# INLINE fail #-} - -instance Fallible SearcherM where - throwErrs e = SearcherM $ lift $ throwErrs e - {-# INLINE throwErrs #-} - addErrCtx ctx (SearcherM (MaybeT m)) = SearcherM $ MaybeT $ - addErrCtx ctx $ m - {-# INLINE addErrCtx #-} - -instance Alternative SearcherM where - empty = SearcherM $ MaybeT $ return Nothing - SearcherM (MaybeT m1) <|> SearcherM (MaybeT m2) = SearcherM $ MaybeT do - m1 >>= \case - Just ans -> return $ Just ans - Nothing -> m2 + ++ pprint x ++ " != " ++ pprint y -instance Searcher SearcherM where - () = (<|>) - {-# INLINE () #-} - -instance CtxReader SearcherM where - getErrCtx = SearcherM $ lift getErrCtx - {-# INLINE getErrCtx #-} - -instance Searcher [] where - [] m = m - m _ = m - {-# INLINE () #-} - -instance (Monoid w, Searcher m) => Searcher (WriterT w m) where - WriterT m1 WriterT m2 = WriterT (m1 m2) - {-# INLINE () #-} +throwInternal :: (HasCallStack, Fallible m) => String -> m a +throwInternal s = throwErr $ InternalErr $ s ++ "\n" ++ prettyCallStack callStack ++ "\n" instance (Monoid w, Fallible m) => Fallible (WriterT w m) where - throwErrs errs = lift $ throwErrs errs - {-# INLINE throwErrs #-} - addErrCtx ctx (WriterT m) = WriterT $ addErrCtx ctx m - {-# INLINE addErrCtx #-} - -instance Searcher m => Searcher (ReaderT r m) where - ReaderT f1 ReaderT f2 = ReaderT \r -> f1 r f2 r - {-# INLINE () #-} + throwErr errs = lift $ throwErr errs + {-# INLINE throwErr #-} instance Fallible [] where - throwErrs _ = [] - {-# INLINE throwErrs #-} - addErrCtx _ m = m - {-# INLINE addErrCtx #-} + throwErr _ = [] + {-# INLINE throwErr #-} instance Fallible Maybe where - throwErrs _ = Nothing - {-# INLINE throwErrs #-} - addErrCtx _ m = m - {-# INLINE addErrCtx #-} - --- === small pretty-printing utils === --- These are here instead of in PPrint.hs for import cycle reasons - -pprint :: Pretty a => a -> String -pprint x = docAsStr $ pretty x -{-# SCC pprint #-} - -docAsStr :: Doc ann -> String -docAsStr doc = T.unpack $ renderStrict $ layoutPretty layout $ doc - -layout :: LayoutOptions -layout = if unbounded then LayoutOptions Unbounded else defaultLayoutOptions - where unbounded = unsafePerformIO $ (Just "1"==) <$> lookupEnv "DEX_PPRINT_UNBOUNDED" - -traverseMergingErrs :: (Traversable f, FallibleApplicative m) - => (a -> m b) -> f a -> m (f b) -traverseMergingErrs f xs = - fromFallibleApplicativeWrapper $ traverse (\x -> FallibleApplicativeWrapper $ f x) xs + throwErr _ = Nothing + {-# INLINE throwErr #-} -- === instances === -instance MonadFail FallibleM where - fail s = throw MonadFailErr s - {-# INLINE fail #-} - instance Fallible Except where - throwErrs errs = Failure errs - {-# INLINE throwErrs #-} - - addErrCtx _ (Success ans) = Success ans - addErrCtx ctx (Failure (Errs errs)) = - Failure $ Errs [Err errTy (ctx <> ctx') s | Err errTy ctx' s <- errs] - {-# INLINE addErrCtx #-} - -instance FallibleApplicative Except where - mergeErrs (Success x) (Success y) = Success (x, y) - mergeErrs x y = Failure (getErrs x <> getErrs y) - where getErrs :: Except a -> Errs - getErrs = \case Failure e -> e - Success _ -> mempty + throwErr errs = Failure errs + {-# INLINE throwErr #-} instance MonadFail Except where - fail s = Failure $ Errs [Err CompilerErr mempty s] + fail s = Failure $ SearchFailure s {-# INLINE fail #-} -instance Exception Errs - -instance Show Errs where - show errs = pprint errs - -instance Pretty Err where - pretty (Err e ctx s) = pretty e <> pretty s <> prettyCtx - -- TODO: figure out a more uniform way to newlines - where prettyCtx = case ctx of - ErrCtx _ (SrcPosCtx Nothing _) [] Nothing -> mempty - _ -> hardline <> pretty ctx - -instance Pretty ErrCtx where - pretty (ErrCtx maybeTextCtx maybePosCtx messages stack) = - -- The order of messages is outer-scope-to-inner-scope, but we want to print - -- them starting the other way around (Not for a good reason. It's just what - -- we've always done.) - prettyLines (reverse messages) <> highlightedSource <> prettyStack - where - highlightedSource = case (maybeTextCtx, maybePosCtx) of - (Just (offset, text), SrcPosCtx (Just (start, stop)) _) -> - hardline <> pretty (highlightRegion (start - offset, stop - offset) text) - _ -> mempty - prettyStack = case stack of - Nothing -> mempty - Just s -> hardline <> "Compiler stack trace:" <> nest 2 (hardline <> prettyLines s) - -instance Pretty a => Pretty (Except a) where - pretty (Success x) = "Success:" <+> pretty x - pretty (Failure e) = "Failure:" <+> pretty e - -instance Pretty ErrType where - pretty e = case e of - -- NoErr tags a chunk of output that was promoted into the Err ADT - -- by appending Results. - NoErr -> "" - ParseErr -> "Parse error:" - SyntaxErr -> "Syntax error: " - TypeErr -> "Type error:" - KindErr -> "Kind error:" - LinErr -> "Linearity error: " - IRVariantErr -> "Internal IR validation error: " - VarDefErr -> "Error in (earlier) definition of variable: " - UnboundVarErr -> "Error: variable not in scope: " - AmbiguousVarErr -> "Error: ambiguous variable: " - RepeatedVarErr -> "Error: variable already defined: " - RepeatedPatVarErr -> "Error: variable already defined within pattern: " - InvalidPatternErr -> "Error: not a valid pattern: " - NotImplementedErr -> - "Not implemented:" <> line <> - "Please report this at github.com/google-research/dex-lang/issues\n" <> line - CompilerErr -> - "Compiler bug!" <> line <> - "Please report this at github.com/google-research/dex-lang/issues\n" <> line - DataIOErr -> "IO error: " - MiscErr -> "Error:" - RuntimeErr -> "Runtime error" - ZipErr -> "Zipping error" - EscapedNameErr -> "Leaked local variables:" - ModuleImportErr -> "Module import error: " - MonadFailErr -> "MonadFail error (internal error)" +instance Exception Err instance Fallible m => Fallible (ReaderT r m) where - throwErrs errs = lift $ throwErrs errs - {-# INLINE throwErrs #-} - addErrCtx ctx (ReaderT f) = ReaderT \r -> addErrCtx ctx $ f r - {-# INLINE addErrCtx #-} + throwErr errs = lift $ throwErr errs + {-# INLINE throwErr #-} instance Catchable m => Catchable (ReaderT r m) where ReaderT f `catchErr` handler = ReaderT \r -> f r `catchErr` \e -> runReaderT (handler e) r -instance FallibleApplicative m => FallibleApplicative (ReaderT r m) where - mergeErrs (ReaderT f1) (ReaderT f2) = - ReaderT \r -> mergeErrs (f1 r) (f2 r) - -instance CtxReader m => CtxReader (ReaderT r m) where - getErrCtx = lift getErrCtx - {-# INLINE getErrCtx #-} - -instance Pretty Errs where - pretty (Errs [err]) = pretty err - pretty (Errs errs) = prettyLines errs - instance Fallible m => Fallible (StateT s m) where - throwErrs errs = lift $ throwErrs errs - {-# INLINE throwErrs #-} - addErrCtx ctx (StateT f) = StateT \s -> addErrCtx ctx $ f s - {-# INLINE addErrCtx #-} + throwErr errs = lift $ throwErr errs + {-# INLINE throwErr #-} instance Catchable m => Catchable (StateT s m) where StateT f `catchErr` handler = StateT \s -> f s `catchErr` \e -> runStateT (handler e) s -instance CtxReader m => CtxReader (StateT s m) where - getErrCtx = lift getErrCtx - {-# INLINE getErrCtx #-} - -instance Semigroup ErrCtx where - ErrCtx text (SrcPosCtx p spanId) ctxStrs stk <> ErrCtx text' (SrcPosCtx p' spanId') ctxStrs' stk' = - ErrCtx (leftmostJust text text') - (SrcPosCtx (rightmostJust p p') (rightmostJust spanId spanId')) - (ctxStrs <> ctxStrs') - (leftmostJust stk stk') -- We usually extend errors form the right - -instance Monoid ErrCtx where - mempty = ErrCtx Nothing emptySrcPosCtx [] Nothing - --- === misc util stuff === - -leftmostJust :: Maybe a -> Maybe a -> Maybe a -leftmostJust (Just x) _ = Just x -leftmostJust Nothing y = y - -rightmostJust :: Maybe a -> Maybe a -> Maybe a -rightmostJust = flip leftmostJust - -prettyLines :: (Foldable f, Pretty a) => f a -> Doc ann -prettyLines xs = foldMap (\d -> pretty d <> hardline) xs - -highlightRegion :: (Int, Int) -> Text -> Text -highlightRegion pos@(low, high) s - | low > high || high > T.length s = - error $ "Bad region: \n" ++ show pos ++ "\n" ++ T.unpack s - | otherwise = - -- TODO: flag to control line numbers - -- (disabling for now because it makes quine tests tricky) - -- "Line " ++ show (1 + lineNum) ++ "\n" - allLines !! lineNum <> "\n" <> - T.replicate start " " <> T.replicate (stop - start) "^" <> "\n" - where - allLines = T.lines s - (lineNum, start, stop) = getPosTriple pos allLines - -getPosTriple :: (Int, Int) -> [Text] -> (Int, Int, Int) -getPosTriple (start, stop) lines_ = (lineNum, start - offset, stop') - where - lineLengths = map ((+1) . T.length) lines_ - lineOffsets = cumsum lineLengths - lineNum = maxLT lineOffsets start - offset = lineOffsets !! lineNum - stop' = min (stop - offset) (lineLengths !! lineNum) - -cumsum :: [Int] -> [Int] -cumsum xs = scanl (+) 0 xs - -maxLT :: Ord a => [a] -> a -> Int -maxLT [] _ = 0 -maxLT (x:xs) n = if n < x then -1 - else 1 + maxLT xs n +instance Pretty Err where + pretty e = pretty $ printErr e + +instance ToJSON SrcId +deriving instance ToJSONKey SrcId + +instance Hashable InfVarDesc +instance Store InfVarDesc diff --git a/src/lib/Export.hs b/src/lib/Export.hs index f7ab3184d..466afbb5d 100644 --- a/src/lib/Export.hs +++ b/src/lib/Export.hs @@ -21,6 +21,7 @@ import Foreign.Ptr import Builder import Core import Err +import PPrint import IRVariants import Name import QueryType @@ -29,6 +30,7 @@ import Subst hiding (Rename) import TopLevel import Types.Core import Types.Imp +import Types.Top import Types.Primitives hiding (sizeOf) type ExportAtomNameC = AtomNameC CoreIR @@ -45,14 +47,14 @@ prepareFunctionForExport :: (Mut n, Topper m) => CallingConvention -> CAtom n -> m n ExportNativeFunction prepareFunctionForExport cc f = do naryPi <- case getType f of - Pi piTy -> return piTy - _ -> throw TypeErr "Only first-order functions can be exported" + TyCon (Pi piTy) -> return piTy + _ -> throwErr $ MiscErr $ MiscMiscErr "Only first-order functions can be exported" sig <- liftExportSigM $ corePiToExportSig cc naryPi closedSig <- case hoistToTop sig of HoistFailure _ -> - throw TypeErr $ "Types of exported functions have to be closed terms. Got: " ++ pprint naryPi + throwErr $ MiscErr $ MiscMiscErr $ "Types of exported functions have to be closed terms. Got: " ++ pprint naryPi HoistSuccess s -> return s - f' <- liftBuilder $ buildCoreLam naryPi \xs -> naryApp (sink f) (Var <$> xs) + f' <- liftBuilder $ buildCoreLam naryPi \xs -> naryApp (sink f) (toAtom <$> xs) fSimp <- simplifyTopFunction $ coreLamToTopLam f' fImp <- compileTopLevelFun cc fSimp nativeFun <- toCFunction "userFunc" fImp >>= emitObjFile >>= loadObject @@ -66,7 +68,7 @@ prepareSLamForExport cc f@(TopLam _ naryPi _) = do sig <- liftExportSigM $ simpPiToExportSig cc naryPi closedSig <- case hoistToTop sig of HoistFailure _ -> - throw TypeErr $ "Types of exported functions have to be closed terms. Got: " ++ pprint naryPi + throwErr $ MiscErr $ MiscMiscErr $ "Types of exported functions have to be closed terms. Got: " ++ pprint naryPi HoistSuccess s -> return s fImp <- compileTopLevelFun cc f nativeFun <- toCFunction "userFunc" fImp >>= emitObjFile >>= loadObject @@ -87,7 +89,7 @@ instance FromName (Rename r) where fromName = JustRefer newtype ExportSigM (r::IR) (i::S) (o::S) (a:: *) = ExportSigM { - runExportSigM :: SubstReaderT (Rename r) (EnvReaderT FallibleM) i o a } + runExportSigM :: SubstReaderT (Rename r) (EnvReaderT Except) i o a } deriving ( Functor, Applicative, Monad, ScopeReader, EnvExtender, Fallible , EnvReader, SubstReader (Rename r), MonadFail) @@ -95,7 +97,7 @@ liftExportSigM :: (EnvReader m, Fallible1 m) => ExportSigM r n n a -> m n a liftExportSigM cont = do Distinct <- getDistinct env <- unsafeGetEnv - liftExcept $ runFallibleM $ runEnvReaderT env $ runSubstReaderT idSubst $ + liftExcept $ runEnvReaderT env $ runSubstReaderT idSubst $ runExportSigM cont corePiToExportSig :: CallingConvention @@ -103,7 +105,7 @@ corePiToExportSig :: CallingConvention corePiToExportSig cc (CorePiType _ expls tbs (EffTy effs resultTy)) = do case effs of Pure -> return () - _ -> throw TypeErr "Only pure functions can be exported" + _ -> throwErr $ MiscErr $ MiscMiscErr "Only pure functions can be exported" goArgs cc Empty [] (zipAttrs expls tbs) resultTy simpPiToExportSig :: CallingConvention @@ -111,7 +113,7 @@ simpPiToExportSig :: CallingConvention simpPiToExportSig cc (PiType bs (EffTy effs resultTy)) = do case effs of Pure -> return () - _ -> throw TypeErr "Only pure functions can be exported" + _ -> throwErr $ MiscErr $ MiscMiscErr "Only pure functions can be exported" bs' <- return $ fmapNest (\b -> WithAttrB Explicit b) bs goArgs cc Empty [] bs' resultTy @@ -143,11 +145,11 @@ goResult :: IRRep r => Type r i Nest ExportResult o o' -> ExportSigM r i o' a) -> ExportSigM r i o a goResult ty cont = case ty of - ProdTy [one] -> + TyCon (ProdType [one]) -> goResult one cont - ProdTy (lty:rest) -> + TyCon (ProdType (lty:rest)) -> goResult lty \lres -> - goResult (ProdTy rest) \rres -> + goResult (TyCon (ProdType rest)) \rres -> cont $ lres >>> rres _ -> do ety <- toExportType ty @@ -157,33 +159,29 @@ goResult ty cont = case ty of toExportType :: IRRep r => Type r i -> ExportSigM r i o (ExportType o) toExportType ty = case ty of BaseTy (Scalar sbt) -> return $ ScalarType sbt - NewtypeTyCon Nat -> return $ ScalarType IdxRepScalarBaseTy + TyCon (NewtypeTyCon Nat) -> return $ ScalarType IdxRepScalarBaseTy TabTy _ _ _ -> parseTabTy ty >>= \case Nothing -> unsupported Just ety -> return ety _ -> unsupported - where unsupported = throw TypeErr $ "Unsupported type of argument in exported function: " ++ pprint ty + where unsupported = throwErr $ MiscErr $ MiscMiscErr $ "Unsupported type of argument in exported function: " ++ pprint ty {-# INLINE toExportType #-} parseTabTy :: IRRep r => Type r i -> ExportSigM r i o (Maybe (ExportType o)) parseTabTy = go [] where - go :: forall r i o. IRRep r => [ExportDim o] -> Type r i - -> ExportSigM r i o (Maybe (ExportType o)) + go :: IRRep r => [ExportDim o] -> Type r i -> ExportSigM r i o (Maybe (ExportType o)) go shape = \case - BaseTy (Scalar sbt) -> return $ Just $ RectContArrayPtr sbt shape - NewtypeTyCon Nat -> return $ Just $ RectContArrayPtr IdxRepScalarBaseTy shape - TabTy d (b:>ixty) a -> do - maybeN <- case IxType ixty d of - IxType (NewtypeTyCon (Fin n)) _ -> return $ Just n - IxType _ (IxDictRawFin n) -> return $ Just n - _ -> return Nothing + TyCon (BaseType (Scalar sbt)) -> return $ Just $ RectContArrayPtr sbt shape + TyCon (NewtypeTyCon Nat) -> return $ Just $ RectContArrayPtr IdxRepScalarBaseTy shape + TyCon (TabPi (TabPiType d (b:>ixty) a)) -> do + maybeN <- fromIxFin $ IxType ixty d maybeDim <- case maybeN of - Just (Var v) -> do + Just (Stuck _ (Var v)) -> do s <- getSubst let (Rename v') = s ! atomVarName v return $ Just (ExportDimVar v') - Just (NewtypeCon NatCon (IdxRepVal s)) -> return $ Just (ExportDimLit $ fromIntegral s) + Just (Con (NewtypeCon NatCon (IdxRepVal s))) -> return $ Just (ExportDimLit $ fromIntegral s) Just (IdxRepVal s) -> return $ Just (ExportDimLit $ fromIntegral s) _ -> return Nothing case maybeDim of @@ -193,6 +191,12 @@ parseTabTy = go [] Nothing -> return Nothing _ -> return Nothing + fromIxFin :: IRRep r => IxType r i -> ExportSigM r i o (Maybe (Atom r i)) + fromIxFin = \case + IxType (TyCon (NewtypeTyCon (Fin n))) (DictCon (IxFin _)) -> return $ Just n + IxType _ (DictCon (IxRawFin n)) -> return $ Just n + _ -> return Nothing + data ArgVisibility = ImplicitArg | ExplicitArg data ExportDim n = diff --git a/src/lib/Generalize.hs b/src/lib/Generalize.hs index 58c0721d4..fc120af2e 100644 --- a/src/lib/Generalize.hs +++ b/src/lib/Generalize.hs @@ -7,29 +7,30 @@ module Generalize (generalizeArgs, generalizeIxDict) where import Control.Monad +import Data.Maybe (fromJust) import Core import Err +import PPrint import Types.Core import Inference import IRVariants import QueryType import Name import Subst -import MTL1 import Types.Primitives +import Types.Top type RolePiBinder = WithAttrB RoleExpl CBinder type RolePiBinders = Nest RolePiBinder -generalizeIxDict :: EnvReader m => Atom CoreIR n -> m n (Generalized CoreIR CAtom n) +generalizeIxDict :: EnvReader m => CDict n -> m n (Generalized CoreIR CDict n) generalizeIxDict dict = liftGeneralizerM do dict' <- sinkM dict dictTy <- return $ getType dict' dictTyGeneralized <- generalizeType dictTy - dictGeneralized <- liftEnvReaderM $ generalizeDict dictTyGeneralized dict' - return dictGeneralized --- {-# INLINE generalizeIxDict #-} + liftEnvReaderM $ generalizeDict dictTyGeneralized dict' +{-# INLINE generalizeIxDict #-} generalizeArgs ::EnvReader m => CorePiType n -> [Atom CoreIR n] -> m n (Generalized CoreIR (ListE CAtom) n) generalizeArgs fTy argsTop = liftGeneralizerM $ runSubstReaderT idSubst do @@ -40,13 +41,12 @@ generalizeArgs fTy argsTop = liftGeneralizerM $ runSubstReaderT idSubst do -> SubstReaderT AtomSubstVal GeneralizerM i n [Atom CoreIR n] go (Nest (WithAttrB expl b) bs) (arg:args) = do ty' <- substM $ binderType b - arg' <- case (ty', expl) of - (TyKind, _) -> liftSubstReaderT case arg of - Type t -> Type <$> generalizeType t - _ -> error "not a type" - (DictTy _, Inferred Nothing (Synth _)) -> generalizeDict ty' arg + arg' <- liftSubstReaderT case (ty', expl) of + (TyKind, _) -> toAtom <$> generalizeType (fromJust $ toMaybeType arg) + (TyCon (DictTy _), Inferred Nothing (Synth _)) -> + toAtom <$> generalizeDict ty' (fromJust $ toMaybeDict arg) _ -> isData ty' >>= \case - True -> liftM Var $ liftSubstReaderT $ emitGeneralizationParameter ty' arg + True -> toAtom <$> emitGeneralizationParameter ty' arg False -> do -- Unlike in `inferRoles` in `Inference`, it's ok to have non-data, -- non-type, non-dict arguments (e.g. a function). We just don't @@ -108,11 +108,9 @@ emitGeneralizationParameter ty val = GeneralizerM do -- Given a type (an Atom of type `Type`), abstracts over all data components generalizeType :: Type CoreIR n -> GeneralizerM n (Type CoreIR n) generalizeType ty = traverseTyParams ty \paramRole paramReqTy param -> case paramRole of - TypeParam -> Type <$> case param of - Type t -> generalizeType t - _ -> error "not a type" - DictParam -> generalizeDict paramReqTy param - DataParam -> Var <$> emitGeneralizationParameter paramReqTy param + TypeParam -> toAtom <$> generalizeType (fromJust $ toMaybeType param) + DictParam -> toAtom <$> generalizeDict paramReqTy (fromJust $ toMaybeDict param) + DataParam -> toAtom <$> emitGeneralizationParameter paramReqTy param -- === role-aware type traversal === @@ -125,27 +123,28 @@ traverseTyParams => CType n -> (forall l . DExt n l => ParamRole -> CType l -> CAtom l -> m l (CAtom l)) -> m n (CType n) -traverseTyParams ty f = getDistinct >>= \Distinct -> case ty of - DictTy (DictType sn name params) -> do - Abs paramRoles UnitE <- getClassRoleBinders name - params' <- traverseRoleBinders f paramRoles params - return $ DictTy $ DictType sn name params' - TabPi (TabPiType (IxDictAtom d) (b:>iTy) resultTy) -> do +traverseTyParams (StuckTy _ _) _ = error "shouldn't have StuckTy left" +traverseTyParams (TyCon ty) f = liftM TyCon $ getDistinct >>= \Distinct -> case ty of + DictTy dictTy -> DictTy <$> case dictTy of + DictType sn name params -> do + Abs paramRoles UnitE <- getClassRoleBinders name + params' <- traverseRoleBinders f paramRoles params + return $ DictType sn name params' + IxDictType t -> IxDictType <$> f' TypeParam TyKind t + DataDictType t -> DataDictType <$> f' TypeParam TyKind t + TabPi (TabPiType d (b:>iTy) resultTy) -> do iTy' <- f' TypeParam TyKind iTy - dictTy <- liftM ignoreExcept $ runFallibleT1 $ DictTy <$> ixDictType iTy' - d' <- f DictParam dictTy d + let dictTy = toType $ IxDictType iTy' + d' <- fromJust . toMaybeDict <$> f DictParam dictTy (toAtom d) withFreshBinder (getNameHint b) iTy' \(b':>_) -> do resultTy' <- applyRename (b@>binderName b') resultTy >>= (f' TypeParam TyKind) - return $ TabTy (IxDictAtom d') (b':>iTy') resultTy' - -- shouldn't need this once we can exclude IxDictFin and IxDictSpecialized from CoreI - TabPi t -> return $ TabPi t - TC tc -> TC <$> case tc of - BaseType b -> return $ BaseType b - ProdType tys -> ProdType <$> forM tys \t -> f' TypeParam TyKind t - RefType _ _ -> error "not implemented" -- how should we handle the ParamRole for the heap parameter? - SumType tys -> SumType <$> forM tys \t -> f' TypeParam TyKind t - TypeKind -> return TypeKind - HeapType -> return HeapType + return $ TabPi $ TabPiType d' (b':>iTy') resultTy' + BaseType b -> return $ BaseType b + ProdType tys -> ProdType <$> forM tys \t -> f' TypeParam TyKind t + RefType _ _ -> error "not implemented" -- how should we handle the ParamRole for the heap parameter? + SumType tys -> SumType <$> forM tys \t -> f' TypeParam TyKind t + TypeKind -> return TypeKind + HeapType -> return HeapType NewtypeTyCon con -> NewtypeTyCon <$> case con of Nat -> return Nat Fin n -> Fin <$> f DataParam NatTy n @@ -157,11 +156,7 @@ traverseTyParams ty f = getDistinct >>= \Distinct -> case ty of _ -> error $ "Not implemented: " ++ pprint ty where f' :: forall l . DExt n l => ParamRole -> CType l -> CType l -> m l (CType l) - f' r t x = fromType <$> f r t (Type x) - - fromType :: CAtom l -> CType l - fromType (Type t) = t - fromType x = error $ "not a type: " ++ pprint x + f' r t x = fromJust <$> toMaybeType <$> f r t (toAtom x) {-# INLINE traverseTyParams #-} traverseRoleBinders @@ -191,7 +186,7 @@ getDataDefRoleBinders def = do getClassRoleBinders :: EnvReader m => ClassName n -> m n (Abs RolePiBinders UnitE n) getClassRoleBinders def = do - ClassDef _ _ _ roleExpls bs _ _ <- lookupClassDef def + ClassDef _ _ _ _ roleExpls bs _ _ <- lookupClassDef def return $ Abs (zipAttrs roleExpls bs) UnitE {-# INLINE getClassRoleBinders #-} diff --git a/src/lib/Imp.hs b/src/lib/Imp.hs index d74333f38..18f3f7004 100644 --- a/src/lib/Imp.hs +++ b/src/lib/Imp.hs @@ -9,8 +9,7 @@ {-# OPTIONS_GHC -Wno-orphans #-} module Imp - ( toImpFunction - , impFunType, getIType, abstractLinktimeObjects + ( toImpFunction, repValAtom, impFunType, getIType, abstractLinktimeObjects , repValFromFlatList, addImpTracing -- These are just for the benefit of serialization/printing. otherwise we wouldn't need them , BufferType (..), IdxNest, IndexStructure, IExprInterpretation (..), typeToTree @@ -25,7 +24,6 @@ import Data.Maybe (fromJust, isJust) import Data.Text.Prettyprint.Doc import Control.Category import Control.Monad.Identity -import Control.Monad.Reader import Control.Monad.Writer.Strict import Control.Monad.State.Strict hiding (State) import qualified Control.Monad.State.Strict as MTL @@ -39,15 +37,16 @@ import Err import IRVariants import MTL1 import Name +import PPrint import Subst import QueryType import Types.Core import Types.Imp import Types.Primitives +import Types.Top import Util (forMFilter, Tree (..), zipTrees, enumerate) -toImpFunction :: EnvReader m - => CallingConvention -> STopLam n -> m n (ImpFunction n) +toImpFunction :: EnvReader m => CallingConvention -> STopLam n -> m n (ImpFunction n) toImpFunction cc (TopLam True destTy lam) = do LamExpr bsAndRefB body <- return lam PairB bs destB <- case popNest bsAndRefB of @@ -64,14 +63,14 @@ toImpFunction cc (TopLam True destTy lam) = do RefTy _ ansTy -> allocDestUnmanaged =<< substM ansTy _ -> error "Expected a reference type for body destination" extendSubst (destB @> SubstVal (destToAtom dest)) do - void $ translateBlock body + void $ translateExpr body resultAtom <- loadAtom dest repValToList <$> atomToRepVal resultAtom _ -> do (argAtoms, resultDest) <- interpretImpArgsWithCC cc (sink ty) vs extendSubst (bs @@> (SubstVal <$> argAtoms)) do extendSubst (destB @> SubstVal (destToAtom (sink resultDest))) do - void $ translateBlock body + void $ translateExpr body return [] toImpFunction _ (TopLam False _ _) = error "expected a lambda in destination-passing form" {-# SCC toImpFunction #-} @@ -246,14 +245,14 @@ instance ImpBuilder ImpM where {-# INLINE extendAllocsToFree #-} instance ImpBuilder m => ImpBuilder (SubstReaderT AtomSubstVal m i) where - emitMultiReturnInstr instr = SubstReaderT $ lift $ emitMultiReturnInstr instr + emitMultiReturnInstr instr = liftSubstReaderT $ emitMultiReturnInstr instr {-# INLINE emitMultiReturnInstr #-} - emitDeclsImp ab = SubstReaderT $ lift $ emitDeclsImp ab + emitDeclsImp ab = liftSubstReaderT $ emitDeclsImp ab {-# INLINE emitDeclsImp #-} - buildScopedImp cont = SubstReaderT $ ReaderT \env -> + buildScopedImp cont = SubstReaderT \env -> buildScopedImp $ runSubstReaderT (sink env) $ cont {-# INLINE buildScopedImp #-} - extendAllocsToFree ptr = SubstReaderT $ lift $ extendAllocsToFree ptr + extendAllocsToFree ptr = liftSubstReaderT $ extendAllocsToFree ptr {-# INLINE extendAllocsToFree #-} instance ImpBuilder m => Imper (SubstReaderT AtomSubstVal m) @@ -269,29 +268,17 @@ liftImpM cont = do -- === the actual pass === -translateBlock :: forall i o. Emits o - => SBlock i -> SubstImpM i o (SAtom o) -translateBlock (Abs decls result) = translateDeclNest decls $ substM result - -translateDeclNestSubst - :: Emits o => Subst AtomSubstVal l o - -> Nest SDecl l i' -> SubstImpM i o (Subst AtomSubstVal i' o) -translateDeclNestSubst !s = \case - Empty -> return s +translateDeclNest :: Emits o => Nest SDecl i i' -> SubstImpM i' o a -> SubstImpM i o a +translateDeclNest decls cont = case decls of + Empty -> cont Nest (Let b (DeclBinding _ expr)) rest -> do - x <- withSubst s $ translateExpr expr - translateDeclNestSubst (s <>> (b@>SubstVal x)) rest - -translateDeclNest :: Emits o - => Nest SDecl i i' -> SubstImpM i' o a -> SubstImpM i o a -translateDeclNest decls cont = do - s <- getSubst - s' <- translateDeclNestSubst s decls - withSubst s' cont + x <- translateExpr expr + extendSubst (b@>SubstVal x) $ translateDeclNest rest cont {-# INLINE translateDeclNest #-} translateExpr :: forall i o. Emits o => SExpr i -> SubstImpM i o (SAtom o) translateExpr expr = confuseGHC >>= \_ -> case expr of + Block _ (Abs decls result) -> translateDeclNest decls $ translateExpr result TopApp (EffTy _ resultTy') f' xs' -> do resultTy <- substM resultTy' f <- substM f' @@ -302,10 +289,10 @@ translateExpr expr = confuseGHC >>= \_ -> case expr of scalarArgs <- liftM toList $ mapM fromScalarAtom xs results <- impCall f scalarArgs restructureScalarOrPairType resultTy results - TabApp _ f' xs' -> do - xs <- mapM substM xs' + TabApp _ f' x' -> do + x <- substM x' f <- atomToRepVal =<< substM f' - repValAtom =<< naryIndexRepVal f (toList xs) + repValAtom =<< indexRepVal f x Atom x -> substM x PrimOp op -> toImpOp op Case e alts (EffTy _ unitResultTy) -> do @@ -313,11 +300,12 @@ translateExpr expr = confuseGHC >>= \_ -> case expr of case unitResultTy of UnitTy -> return () _ -> error $ "Unexpected returning Case in Imp " ++ pprint expr - case trySelectBranch e' of - Just (con, arg) -> do - Abs b body <- return $ alts !! con - extendSubst (b @> SubstVal arg) $ translateBlock body - Nothing -> do + case e' of + Con con -> do + SumCon _ i arg <- return con + Abs b body <- return $ alts !! i + extendSubst (b @> SubstVal arg) $ translateExpr body + Stuck _ _ -> do RepVal sumTy (Branch (tag:xss)) <- atomToRepVal e' ts <- caseAltsBinderTys sumTy tag' <- repValAtom $ RepVal TagRepTy tag @@ -329,9 +317,10 @@ translateExpr expr = confuseGHC >>= \_ -> case expr of emitSwitch tag' (zip xss alts) $ \(xs, Abs b body) -> extendSubst (b @> SubstVal (sink xs)) $ - void $ translateBlock body + void $ translateExpr body return UnitVal TabCon _ _ _ -> error "Unexpected `TabCon` in Imp pass." + Project _ i x -> reduceProj i =<< substM x toImpRefOp :: Emits o => SAtom i -> RefOp SimpIR i -> SubstImpM i o (SAtom o) @@ -365,10 +354,10 @@ toImpRefOp refDest' m = do True -> do BinaryLamExpr xb yb body <- return bc body' <- applySubst (xb @> SubstVal x <.> yb @> SubstVal y) body - ans <- liftBuilderImp $ emitBlock (sink body') + ans <- liftBuilderImp $ emit (sink body') storeAtom accDest ans False -> case accTy of - TabPi t -> do + TyCon (TabPi t) -> do let ixTy = tabIxType t n <- indexSetSizeImp ixTy emitLoop noHint Fwd n \i -> do @@ -394,12 +383,12 @@ toImpOp op = case op of emitLoop (getNameHint b) d n \i -> do idx <- unsafeFromOrdinalImp (sink ixTy) i void $ extendSubst (b @> SubstVal (PairVal idx (sink carry'))) $ - translateBlock body + translateExpr body return carry' RememberDest _ d f -> do UnaryLamExpr b body <- return f d' <- substM d - void $ extendSubst (b @> SubstVal d') $ translateBlock body + void $ extendSubst (b @> SubstVal d') $ translateExpr body return d' Place ref val -> do val' <- substM val @@ -444,7 +433,7 @@ castPtrToVectorType ptr vty = do let PtrType (addrSpace, _) = getIType ptr cast ptr (PtrType (addrSpace, vty)) -toImpMiscOp :: Emits o => MiscOp SimpIR o -> SubstImpM i o (SAtom o) +toImpMiscOp :: forall i o . Emits o => MiscOp SimpIR o -> SubstImpM i o (SAtom o) toImpMiscOp op = case op of ThrowError resultTy -> do emitStatement IThrowError @@ -471,15 +460,14 @@ toImpMiscOp op = case op of returnIExprVal =<< emitInstr =<< (ISelect <$> fsa p <*> fsa x <*> fsa y) SumTag con -> case con of Con (SumCon _ tag _) -> return $ TagRepVal $ fromIntegral tag - RepValAtom dRepVal -> go dRepVal + Stuck _ (RepValAtom dRepVal) -> do + RepVal _ (Branch (tag:_)) <- return dRepVal + return $ toAtom $ RepVal (TagRepTy :: SType o) tag _ -> error $ "Not a data constructor: " ++ pprint con - where go dRepVal = do - RepVal _ (Branch (tag:_)) <- return dRepVal - return $ RepValAtom $ RepVal TagRepTy tag ToEnum ty i -> case ty of - SumTy cases -> do + TyCon (SumType cases) -> do i' <- fromScalarAtom i - return $ RepValAtom $ RepVal ty $ Branch $ Leaf i' : map (const (Branch [])) cases + return $ toAtom $ RepVal ty $ Branch $ Leaf i' : map (const (Branch [])) cases _ -> error $ "Not an enum: " ++ pprint ty OutputStream -> returnIExprVal =<< emitInstr IOutputStream ThrowException _ -> error "shouldn't have ThrowException left" -- also, should be replaced with user-defined errors @@ -532,7 +520,7 @@ toImpTypedHof (TypedHof (EffTy _ resultTy') hof) = do For _ _ _ -> error $ "Unexpected `for` in Imp pass " ++ pprint hof While body -> do body' <- buildBlockImp do - ans <- fromScalarAtom =<< translateBlock body + ans <- fromScalarAtom =<< translateExpr body return [ans] emitStatement $ IWhile body' return UnitVal @@ -542,7 +530,7 @@ toImpTypedHof (TypedHof (EffTy _ resultTy') hof) = do rDest <- allocDest $ getType r' storeAtom rDest r' extendSubst (h @> SubstVal (Con HeapVal) <.> ref @> SubstVal (destToAtom rDest)) $ - translateBlock body + translateExpr body RunWriter d (BaseMonoid e _) f -> do BinaryLamExpr h ref body <- return f let PairTy ansTy accTy = resultTy @@ -556,7 +544,7 @@ toImpTypedHof (TypedHof (EffTy _ resultTy') hof) = do PairE accTy' e'' <- sinkM $ PairE accTy e' liftMonoidEmpty wDest accTy' e'' extendSubst (h @> SubstVal (Con HeapVal) <.> ref @> SubstVal (destToAtom wDest)) $ - translateBlock body >>= storeAtom aDest + translateExpr body >>= storeAtom aDest PairVal <$> loadAtom aDest <*> loadAtom wDest RunState d s f -> do BinaryLamExpr h ref body <- return f @@ -569,10 +557,10 @@ toImpTypedHof (TypedHof (EffTy _ resultTy') hof) = do return (aDest, sDest) storeAtom sDest =<< substM s extendSubst (h @> SubstVal (Con HeapVal) <.> ref @> SubstVal (destToAtom sDest)) $ - translateBlock body >>= storeAtom aDest + translateExpr body >>= storeAtom aDest PairVal <$> loadAtom aDest <*> loadAtom sDest - RunIO body -> translateBlock body - RunInit body -> translateBlock body + RunIO body -> translateExpr body + RunInit body -> translateExpr body where liftMonoidEmpty :: Emits n => Dest n -> SType n -> SAtom n -> SubstImpM i n () liftMonoidEmpty accDest accTy x = do @@ -580,7 +568,7 @@ toImpTypedHof (TypedHof (EffTy _ resultTy') hof) = do alphaEq xTy accTy >>= \case True -> storeAtom accDest x False -> case accTy of - TabPi t -> do + TyCon (TabPi t) -> do let ixTy = tabIxType t n <- indexSetSizeImp ixTy emitLoop noHint Fwd n \i -> do @@ -696,28 +684,27 @@ typeToTree :: EnvReader m => SType n -> m n (Tree (LeafType n)) typeToTree tyTop = return $ go REmpty tyTop where go :: RNest (TypeCtxLayer SimpIR) n l -> SType l -> Tree (LeafType n) - go ctx = \case - BaseTy b -> Leaf $ LeafType (unRNest ctx) b - TabTy d b bodyTy -> go (RNest ctx (TabCtx (PairB (LiftB d) b))) bodyTy - RefTy _ t -> go (RNest ctx RefCtx) t + go ctx (TyCon con) = case con of + BaseType b -> Leaf $ LeafType (unRNest ctx) b + TabPi (TabPiType d b bodyTy) -> go (RNest ctx (TabCtx (PairB (LiftB d) b))) bodyTy + RefType _ t -> go (RNest ctx RefCtx) t DepPairTy (DepPairType _ (b:>t1) (t2)) -> do let tree1 = rec t1 let tree2 = go (RNest ctx (DepPairCtx (JustB (b:>t1)))) t2 Branch [tree1, tree2] - ProdTy ts -> Branch $ map rec ts - SumTy ts -> do + ProdType ts -> Branch $ map rec ts + SumType ts -> do let tag = rec TagRepTy let xs = map rec ts Branch $ tag:xs - TC HeapType -> Branch [] - ty -> error $ "not implemented " ++ pprint ty + HeapType -> Branch [] where rec = go ctx traverseScalarRepTys :: EnvReader m => SType n -> (LeafType n -> m n a) -> m n (Tree a) traverseScalarRepTys ty f = traverse f =<< typeToTree ty {-# INLINE traverseScalarRepTys #-} -storeRepVal :: Emits n => Dest n -> SRepVal n -> SubstImpM i n () +storeRepVal :: Emits n => Dest n -> RepVal n -> SubstImpM i n () storeRepVal (Dest _ destTree) repVal@(RepVal _ valTree) = do leafTys <- valueToTree repVal forM_ (zipTrees (zipTrees leafTys destTree) valTree) \((leafTy, ptr), val) -> do @@ -726,16 +713,16 @@ storeRepVal (Dest _ destTree) repVal@(RepVal _ valTree) = do -- Like `typeToTree`, but when we additionally have the value, we can populate -- the existentially-hidden fields. -valueToTree :: EnvReader m => SRepVal n -> m n (Tree (LeafType n)) +valueToTree :: EnvReader m => RepVal n -> m n (Tree (LeafType n)) valueToTree (RepVal tyTop valTop) = do go REmpty tyTop valTop where go :: EnvReader m => RNest (TypeCtxLayer SimpIR) n l -> SType l -> Tree (IExpr n) -> m n (Tree (LeafType n)) - go ctx ty val = case ty of - BaseTy b -> return $ Leaf $ LeafType (unRNest ctx) b - TabTy d b bodyTy -> go (RNest ctx (TabCtx (PairB (LiftB d) b))) bodyTy val - RefTy _ t -> go (RNest ctx RefCtx) t val + go ctx (TyCon ty) val = case ty of + BaseType b -> return $ Leaf $ LeafType (unRNest ctx) b + TabPi (TabPiType d b bodyTy) -> go (RNest ctx (TabCtx (PairB (LiftB d) b))) bodyTy val + RefType _ t -> go (RNest ctx RefCtx) t val DepPairTy (DepPairType _ (b:>t1) (t2)) -> case val of Branch [v1, v2] -> do case allDepPairCtxs (unRNest ctx) of @@ -750,10 +737,10 @@ valueToTree (RepVal tyTop valTop) = do tree2 <- go (RNest ctx (DepPairCtx (JustB (b:>t1)))) t2 v2 return $ Branch [tree1, tree2] _ -> error "expected a branch" - ProdTy ts -> case val of + ProdType ts -> case val of Branch vals -> Branch <$> zipWithM rec ts vals _ -> error "expected a branch" - SumTy ts -> case val of + SumType ts -> case val of Branch (tagVal:vals) -> do tag <- rec TagRepTy tagVal results <- zipWithM rec ts vals @@ -844,7 +831,7 @@ isNull p = do nullPtrIExpr :: BaseType -> IExpr n nullPtrIExpr baseTy = ILit $ PtrLit (CPU, baseTy) NullPtr -loadRepVal :: (ImpBuilder m, Emits n) => Dest n -> m n (SRepVal n) +loadRepVal :: (ImpBuilder m, Emits n) => Dest n -> m n (RepVal n) loadRepVal (Dest valTy destTree) = do leafTys <- typeToTree valTy RepVal valTy <$> forM (zipTrees leafTys destTree) \(leafTy, ptr) -> do @@ -854,60 +841,55 @@ loadRepVal (Dest valTy destTree) = do _ -> return ptr {-# INLINE loadRepVal #-} -atomToRepVal :: Emits n => SAtom n -> SubstImpM i n (SRepVal n) +atomToRepVal :: Emits n => SAtom n -> SubstImpM i n (RepVal n) atomToRepVal x = RepVal (getType x) <$> go x where go :: Emits n => SAtom n -> SubstImpM i n (Tree (IExpr n)) - go atom = case atom of - RepValAtom dRepVal -> do - (RepVal _ tree) <- return dRepVal - return tree + go (Con con) = case con of DepPair lhs rhs _ -> do lhsTree <- go lhs rhsTree <- go rhs return $ Branch [lhsTree, rhsTree] - Con (Lit l) -> return $ Leaf $ ILit l - Con (ProdCon xs) -> Branch <$> mapM go xs - Con (SumCon cases tag payload) -> do + Lit l -> return $ Leaf $ ILit l + ProdCon xs -> Branch <$> mapM go xs + SumCon cases tag payload -> do tag' <- go $ TagRepVal $ fromIntegral tag xs <- forM (enumerate cases) \(i, t) -> if i == tag then go payload - else buildGarbageVal t <&> \(RepValAtom (RepVal _ tree)) -> tree + else buildGarbageVal t <&> \(Stuck _ (RepValAtom (RepVal _ tree))) -> tree return $ Branch $ tag':xs - Con HeapVal -> return $ Branch [] + HeapVal -> return $ Branch [] + go (Stuck _ stuck) = case stuck of Var v -> lookupAtomName (atomVarName v) >>= \case TopDataBound (RepVal _ tree) -> return tree _ -> error "should only have pointer and data atom names left" PtrVar ty p -> return $ Leaf $ IPtrVar p ty - ProjectElt _ p val -> do - (ps, v) <- return $ asNaryProj p val - lookupAtomName (atomVarName v) >>= \case - TopDataBound (RepVal _ tree) -> applyProjection (toList ps) tree - _ -> error "should only be projecting a data atom" - where - applyProjection :: [Projection] -> Tree (IExpr n) -> SubstImpM i n (Tree (IExpr n)) - applyProjection [] t = return t - applyProjection (i:is) t = do - t' <- applyProjection is t - case i of - UnwrapNewtype -> error "impossible" - ProjectProduct idx -> case t' of - Branch ts -> return $ ts !! idx - _ -> error "should only be projecting a branch" + RepValAtom dRepVal -> do + (RepVal _ tree) <- return dRepVal + return tree + -- TODO: I think we want to be able to rule this one out by insisting that + -- RepValAtom is itself part of Stuck and it can't represent a product. + StuckProject i val -> do + Branch ts <- go =<< mkStuck val + return $ ts !! i + StuckTabApp f x' -> do + f' <- atomToRepVal =<< mkStuck f + RepVal _ t <- indexRepVal f' x' + return t -- XXX: We used to have a function called `destToAtom` which loaded the value -- from the dest. This version is not that. It just lifts a dest into an atom of -- type `Ref _`. destToAtom :: Dest n -> SAtom n -destToAtom (Dest valTy tree) = RepValAtom $ RepVal (RefTy (Con HeapVal) valTy) tree +destToAtom (Dest valTy tree) = toAtom $ RepVal (RefTy (Con HeapVal) valTy) tree atomToDest :: EnvReader m => SAtom n -> m n (Dest n) -atomToDest (RepValAtom val) = do +atomToDest (Stuck _ (RepValAtom val)) = do (RepVal ~(RefTy _ valTy) valTree) <- return val return $ Dest valTy valTree atomToDest atom = error $ "Expected a non-var atom of type `RawRef _`, got: " ++ pprint atom {-# INLINE atomToDest #-} -repValToList :: SRepVal n -> [IExpr n] +repValToList :: RepVal n -> [IExpr n] repValToList (RepVal _ tree) = toList tree -- TODO: augment with device, backend information as needed @@ -980,7 +962,7 @@ storeAtom dest x = storeRepVal dest =<< atomToRepVal x loadAtom :: Emits n => Dest n -> SubstImpM i n (SAtom n) loadAtom d = repValAtom =<< loadRepVal d -repValFromFlatList :: (TopBuilder m, Mut n) => SType n -> [LitVal] -> m n (SRepVal n) +repValFromFlatList :: (TopBuilder m, Mut n) => SType n -> [LitVal] -> m n (RepVal n) repValFromFlatList ty xs = do (litValTree, []) <- runStreamReaderT1 xs $ traverseScalarRepTys ty \_ -> fromJust <$> readStream @@ -996,7 +978,7 @@ litValToIExpr litval = case litval of buildGarbageVal :: Emits n => SType n -> SubstImpM i n (SAtom n) buildGarbageVal ty = - RepValAtom <$> RepVal ty <$> traverseScalarRepTys ty \leafTy -> do + toAtom <$> RepVal ty <$> traverseScalarRepTys ty \leafTy -> do case getIExprInterpretation leafTy of BufferPtr bufferTy -> allocBuffer Managed bufferTy RawValue b -> return $ ILit $ emptyLit b @@ -1004,10 +986,10 @@ buildGarbageVal ty = -- === Operations on dests === indexDest :: Emits n => Dest n -> SAtom n -> SubstImpM i n (Dest n) -indexDest (Dest (TabPi tabTy) tree) i = do +indexDest (Dest (TyCon (TabPi tabTy)) tree) i = do eltTy <- instantiate tabTy [i] ord <- ordinalImp (tabIxType tabTy) i - leafTys <- typeToTree $ TabPi tabTy + leafTys <- typeToTree $ toType tabTy Dest eltTy <$> forM (zipTrees leafTys tree) \(leafTy, ptr) -> do BufferType ixStruct _ <- return $ getRefBufferType leafTy offset <- computeOffsetImp ixStruct ord @@ -1015,23 +997,15 @@ indexDest (Dest (TabPi tabTy) tree) i = do indexDest _ _ = error "expected a reference to a table" {-# INLINE indexDest #-} --- TODO: direct n-ary version for efficiency? -naryIndexRepVal :: Emits n => RepVal SimpIR n -> [SAtom n] -> SubstImpM i n (RepVal SimpIR n) -naryIndexRepVal x [] = return x -naryIndexRepVal x (ix:ixs) = do - x' <- indexRepVal x ix - naryIndexRepVal x' ixs -{-# INLINE naryIndexRepVal #-} - -- TODO: de-dup with indexDest? indexRepValParam :: Emits n - => SRepVal n -> SAtom n -> (SType n -> SType n) + => RepVal n -> SAtom n -> (SType n -> SType n) -> (IExpr n -> SubstImpM i n (IExpr n)) - -> SubstImpM i n (SRepVal n) -indexRepValParam (RepVal (TabPi tabTy) vals) i tyFunc func = do + -> SubstImpM i n (RepVal n) +indexRepValParam (RepVal (TyCon (TabPi tabTy)) vals) i tyFunc func = do eltTy <- instantiate tabTy [i] ord <- ordinalImp (tabIxType tabTy) i - leafTys <- typeToTree (TabPi tabTy) + leafTys <- typeToTree (toType tabTy) vals' <- forM (zipTrees leafTys vals) \(leafTy, ptr) -> do BufferPtr (BufferType ixStruct _) <- return $ getIExprInterpretation leafTy offset <- computeOffsetImp ixStruct ord @@ -1047,14 +1021,11 @@ indexRepValParam (RepVal (TabPi tabTy) vals) i tyFunc func = do indexRepValParam _ _ _ _ = error "expected table type" {-# INLINE indexRepValParam #-} -indexRepVal :: Emits n - => RepVal SimpIR n -> SAtom n -> SubstImpM i n (RepVal SimpIR n) +indexRepVal :: Emits n => RepVal n -> SAtom n -> SubstImpM i n (RepVal n) indexRepVal rep i = indexRepValParam rep i id return {-# INLINE indexRepVal #-} -vectorIndexRepVal :: Emits n - => RepVal SimpIR n -> SAtom n -> SType n - -> SubstImpM i n (RepVal SimpIR n) +vectorIndexRepVal :: Emits n => RepVal n -> SAtom n -> SType n -> SubstImpM i n (RepVal n) vectorIndexRepVal rep i vty = -- Passing `const vty` here depends on knowing that `vectorIndexRepVal` is -- only called on references of scalar base type, so that the give `vty` is, @@ -1064,7 +1035,7 @@ vectorIndexRepVal rep i vty = {-# INLINE vectorIndexRepVal #-} projectDest :: Int -> Dest n -> Dest n -projectDest i (Dest (ProdTy tys) (Branch ds)) = +projectDest i (Dest (TyCon (ProdType tys)) (Branch ds)) = Dest (tys!!i) (ds!!i) projectDest _ (Dest ty _) = error $ "Can't project dest: " ++ pprint ty @@ -1081,7 +1052,7 @@ type SBuilderM = BuilderM SimpIR computeElemCountImp :: Emits n => IndexStructure SimpIR n -> SubstImpM i n (IExpr n) computeElemCountImp Singleton = return $ IIdxRepVal 1 computeElemCountImp idxs = do - result <- coreToImpBuilder do + result <- liftBuilderImp do idxs' <- sinkM idxs computeElemCount idxs' fromScalarAtom result @@ -1090,7 +1061,7 @@ computeOffsetImp :: Emits n => IndexStructure SimpIR n -> IExpr n -> SubstImpM i n (IExpr n) computeOffsetImp idxs ixOrd = do let ixOrd' = toScalarAtom ixOrd - result <- coreToImpBuilder do + result <- liftBuilderImp do PairE idxs' ixOrd'' <- sinkM $ PairE idxs ixOrd' computeOffset idxs' ixOrd'' fromScalarAtom result @@ -1119,11 +1090,11 @@ elemCountPoly (Abs bs UnitE) = case bs of computeSizeGivenOrdinal :: EnvReader m => IxBinder SimpIR n l -> IndexStructure SimpIR l - -> m n (Abs (Binder SimpIR) (Block SimpIR) n) + -> m n (Abs SBinder SExpr n) computeSizeGivenOrdinal (PairB (LiftB d) (b:>t)) idxStruct = liftBuilder do withFreshBinder noHint IdxRepTy \bOrdinal -> Abs bOrdinal <$> buildBlock do - i <- unsafeFromOrdinal (sink $ IxType t d) $ Var $ sink $ binderVar bOrdinal + i <- unsafeFromOrdinal (sink $ IxType t d) $ toAtom $ sink $ binderVar bOrdinal idxStruct' <- applySubst (b@>SubstVal i) idxStruct elemCountPoly $ sink idxStruct' @@ -1151,8 +1122,8 @@ computeOffset (EmptyAbs (Nest b idxs)) idxOrdinal = do computeOffset _ _ = error "Expected a nonempty nest of idx binders" sumUsingPolysImp - :: Emits n => Atom SimpIR n - -> Abs (Binder SimpIR) (Block SimpIR) n -> BuilderM SimpIR n (SAtom n) + :: Emits n => SAtom n + -> Abs SBinder SExpr n -> BuilderM SimpIR n (SAtom n) sumUsingPolysImp lim (Abs i body) = do ab <- hoistDecls i body sumUsingPolys lim ab @@ -1160,30 +1131,31 @@ sumUsingPolysImp lim (Abs i body) = do hoistDecls :: ( Builder SimpIR m, EnvReader m, Emits n , BindsNames b, BindsEnv b, RenameB b, SinkableB b) - => b n l -> SBlock l -> m n (Abs b SBlock n) + => b n l -> SExpr l -> m n (Abs b SExpr n) hoistDecls b block = do emitDecls =<< liftEnvReaderM do - refreshAbs (Abs b block) \b' (Abs decls result) -> - hoistDeclsRec b' Empty decls result + refreshAbs (Abs b block) \b' body -> + hoistDeclsRec b' Empty body {-# INLINE hoistDecls #-} hoistDeclsRec :: (BindsNames b, SinkableB b) - => b n1 n2 -> SDecls n2 n3 -> SDecls n3 n4 -> SAtom n4 - -> EnvReaderM n3 (Abs SDecls (Abs b (Abs SDecls SAtom)) n1) -hoistDeclsRec b declsAbove Empty result = - return $ Abs Empty $ Abs b $ Abs declsAbove result -hoistDeclsRec b declsAbove (Nest decl declsBelow) result = do - let (Let _ expr) = decl - let exprIsPure = isPure expr - refreshAbs (Abs decl (Abs declsBelow result)) - \decl' (Abs declsBelow' result') -> - case exchangeBs (PairB (PairB b declsAbove) decl') of - HoistSuccess (PairB hoistedDecl (PairB b' declsAbove')) | exprIsPure -> do - Abs hoistedDecls fullResult <- hoistDeclsRec b' declsAbove' declsBelow' result' - return $ Abs (Nest hoistedDecl hoistedDecls) fullResult - _ -> hoistDeclsRec b (declsAbove >>> Nest decl' Empty) declsBelow' result' -{-# INLINE hoistDeclsRec #-} + => b n1 n2 -> SDecls n2 n3 -> SExpr n3 + -> EnvReaderM n3 (Abs SDecls (Abs b SExpr) n1) +hoistDeclsRec = undefined +-- hoistDeclsRec b declsAbove Empty result = +-- return $ Abs Empty $ Abs b $ Abs declsAbove result +-- hoistDeclsRec b declsAbove (Nest decl declsBelow) result = do +-- let (Let _ expr) = decl +-- let exprIsPure = isPure expr +-- refreshAbs (Abs decl (Abs declsBelow result)) +-- \decl' (Abs declsBelow' result') -> +-- case exchangeBs (PairB (PairB b declsAbove) decl') of +-- HoistSuccess (PairB hoistedDecl (PairB b' declsAbove')) | exprIsPure -> do +-- Abs hoistedDecls fullResult <- hoistDeclsRec b' declsAbove' declsBelow' result' +-- return $ Abs (Nest hoistedDecl hoistedDecls) fullResult +-- _ -> hoistDeclsRec b (declsAbove >>> Nest decl' Empty) declsBelow' result' +-- {-# INLINE hoistDeclsRec #-} -- === Imp IR builder === @@ -1235,7 +1207,7 @@ buildImpFunction cc argHintsTys body = do return $ ImpFunction impFun $ Abs bs $ ImpBlock decls results buildImpNaryAbs - :: (SinkableE e, HasNamesE e, RenameE e, HoistableE e) + :: HasNamesE e => [(NameHint, IType)] -> (forall l. (Emits l, DExt n l) => [(Name ImpNameC l, BaseType)] -> SubstImpM i l (e l)) -> SubstImpM i n (Abs (Nest IBinder) (Abs (Nest ImpDecl) e) n) @@ -1376,11 +1348,9 @@ fromScalarAtom atom = atomToRepVal atom >>= \case Leaf x -> return x _ -> error $ "Not a scalar atom:" ++ pprint ty -toScalarAtom :: IExpr n -> SAtom n -toScalarAtom x = RepValAtom $ RepVal (BaseTy (getIType x)) (Leaf x) +toScalarAtom :: forall n. IExpr n -> SAtom n +toScalarAtom x = toAtom $ RepVal (BaseTy (getIType x) :: SType n) (Leaf x) --- TODO: we shouldn't need the rank-2 type here because ImpBuilder and Builder --- are part of the same conspiracy. liftBuilderImp :: (Emits n, SubstE AtomSubstVal e, SinkableE e) => (forall l. (Emits l, DExt n l) => BuilderM SimpIR l (e l)) -> SubstImpM i n (e n) @@ -1389,46 +1359,34 @@ liftBuilderImp cont = do dropSubst $ translateDeclNest decls $ substM result {-# INLINE liftBuilderImp #-} -coreToImpBuilder - :: (Emits n, ImpBuilder m, SinkableE e, RenameE e, SubstE AtomSubstVal e ) - => (forall l. (Emits l, DExt n l) => BuilderM SimpIR l (e l)) - -> m n (e n) -coreToImpBuilder cont = do - block <- liftBuilder $ buildScoped cont - result <- liftImpM $ buildScopedImp $ dropSubst do - Abs decls result <- sinkM block - translateDeclNest decls $ substM result - emitDeclsImp result -{-# INLINE coreToImpBuilder #-} - -- === Type classes === ordinalImp :: Emits n => IxType SimpIR n -> SAtom n -> SubstImpM i n (IExpr n) -ordinalImp (IxType _ dict) i = fromScalarAtom =<< case dict of - IxDictRawFin _ -> return i - IxDictSpecialized _ d params -> do +ordinalImp (IxType _ (DictCon dict)) i = fromScalarAtom =<< case dict of + IxRawFin _ -> return i + IxSpecialized d params -> do appSpecializedIxMethod d Ordinal (params ++ [i]) unsafeFromOrdinalImp :: Emits n => IxType SimpIR n -> IExpr n -> SubstImpM i n (SAtom n) -unsafeFromOrdinalImp (IxType _ dict) i = do +unsafeFromOrdinalImp (IxType _ (DictCon dict)) i = do let i' = toScalarAtom i case dict of - IxDictRawFin _ -> return i' - IxDictSpecialized _ d params -> + IxRawFin _ -> return i' + IxSpecialized d params -> appSpecializedIxMethod d UnsafeFromOrdinal (params ++ [i']) indexSetSizeImp :: Emits n => IxType SimpIR n -> SubstImpM i n (IExpr n) -indexSetSizeImp (IxType _ dict) = do +indexSetSizeImp (IxType _ (DictCon dict)) = do fromScalarAtom =<< case dict of - IxDictRawFin n -> return n - IxDictSpecialized _ d params -> + IxRawFin n -> return n + IxSpecialized d params -> appSpecializedIxMethod d Size (params ++ []) appSpecializedIxMethod :: Emits n => SpecDictName n -> IxMethod -> [SAtom n] -> SubstImpM i n (SAtom n) appSpecializedIxMethod d method args = do SpecializedDict _ (Just fs) <- lookupSpecDict d TopLam _ _ (LamExpr bs body) <- return $ fs !! fromEnum method - dropSubst $ extendSubst (bs @@> map SubstVal args) $ translateBlock body + dropSubst $ extendSubst (bs @@> map SubstVal args) $ translateExpr body -- === Abstracting link-time objects === @@ -1466,10 +1424,10 @@ abstractLinktimeObjects f = do isSingletonType :: Type SimpIR n -> Bool isSingletonType topTy = isJust $ checkIsSingleton topTy where - checkIsSingleton :: Type r n -> Maybe () - checkIsSingleton ty = case ty of + checkIsSingleton :: SType n -> Maybe () + checkIsSingleton (TyCon ty) = case ty of TabPi (TabPiType _ _ body) -> checkIsSingleton body - TC (ProdType tys) -> mapM_ checkIsSingleton tys + ProdType tys -> mapM_ checkIsSingleton tys _ -> Nothing singletonTypeVal :: (EnvReader m) @@ -1479,7 +1437,7 @@ singletonTypeVal ty = do if length tree == 0 then do -- The tree has 0 of these if the type is empty let tree' = fmap (const $ ILit $ Int32Lit 0) tree - return $ Just $ RepValAtom $ RepVal ty tree' + Just <$> mkStuck (RepValAtom $ RepVal ty tree') else return Nothing {-# INLINE singletonTypeVal #-} diff --git a/src/lib/ImpToLLVM.hs b/src/lib/ImpToLLVM.hs index b747c2c2c..e5736b35e 100644 --- a/src/lib/ImpToLLVM.hs +++ b/src/lib/ImpToLLVM.hs @@ -45,19 +45,17 @@ import qualified Data.Set as S import CUDA (getCudaArchitecture) import Core -import Err import Imp import LLVM.CUDA (LLVMKernel (..), compileCUDAKernel, ptxDataLayout, ptxTargetTriple) -import Logging import Subst import Name import PPrint import RawName qualified as R import Types.Core import Types.Imp -import Types.Misc import Types.Primitives import Types.Source +import Types.Top import Util (IsBool (..), bindM2, enumerate) -- === Compile monad === @@ -95,7 +93,7 @@ newtype CompileM i o a = , EnvReader, SubstReader OperandSubstVal ) instance MonadState CompileState (CompileM i o) where - state f = CompileM $ SubstReaderT $ lift $ EnvReaderT $ lift $ state f + state f = CompileM $ liftSubstReaderT $ EnvReaderT $ lift $ state f class MonadState CompileState m => LLVMBuilder (m::MonadKind) @@ -110,7 +108,7 @@ instance Compiler CompileM -- === Imp to LLVM === impToLLVM :: (EnvReader m, MonadIO1 m) - => FilteredLogger PassName [Output] -> NameHint + => PassLogger -> NameHint -> ClosedImpFunction n -> m n (WithCNameInterface L.Module) impToLLVM logger fNameHint (ClosedImpFunction funBinders ptrBinders impFun) = do @@ -186,7 +184,7 @@ impToLLVM logger fNameHint (ClosedImpFunction funBinders ptrBinders impFun) = do compileFunction :: (EnvReader m, MonadIO1 m) - => FilteredLogger PassName [Output] -> L.Name + => PassLogger -> L.Name -> OperandEnv i o -> ImpFunction i -> m o ([L.Definition], S.Set ExternFunSpec, [L.Name]) compileFunction logger fName env fun@(ImpFunction (IFunType cc argTys retTys) @@ -311,7 +309,7 @@ compileInstr instr = case instr of compileIf p' (compileVoidBlock cons) (compileVoidBlock alt) IQueryParallelism f s -> do let IFunType cc _ _ = snd f - let kernelFuncName = topLevelFunName $ fst f + let kernelFuncName = topLevelFunName $ fromString $ fst f n <- (`asIntWidth` i64) =<< compileExpr s case cc of MCThreadLaunch -> do @@ -339,7 +337,7 @@ compileInstr instr = case instr of ILaunch (fname, IFunType cc _ _) size args -> [] <$ do size' <- (`asIntWidth` i64) =<< compileExpr size args' <- mapM compileExpr args - let kernelFuncName = topLevelFunName fname + let kernelFuncName = topLevelFunName (fromString fname) case cc of MCThreadLaunch -> do kernelParams <- packArgs args' @@ -508,11 +506,11 @@ compileInstr instr = case instr of let resultTys = map scalarTy impResultTys case cc of FFICC -> do - ans <- emitExternCall (makeFunSpec fname ty) args' + ans <- emitExternCall (makeFunSpec (fromString fname) ty) args' return [ans] FFIMultiResultCC -> do resultPtr <- makeMultiResultAlloc resultTys - emitVoidExternCall (makeFunSpec fname ty) (resultPtr : args') + emitVoidExternCall (makeFunSpec (fromString fname) ty) (resultPtr : args') loadMultiResultAlloc resultTys resultPtr _ -> error $ "Unsupported calling convention: " ++ show cc DebugPrint fmtStr x -> [] <$ do @@ -539,11 +537,11 @@ compileInstr instr = case instr of -- TODO: use a careful naming discipline rather than strings -- (this is only used on the CUDA path which is currently broken anyway) topLevelFunName :: SourceName -> L.Name -topLevelFunName name = fromString name +topLevelFunName name = fromString $ pprint name makeFunSpec :: SourceName -> IFunType -> ExternFunSpec makeFunSpec name impFunTy = - ExternFunSpec (L.Name (fromString name)) retTy [] [] argTys + ExternFunSpec (L.Name (fromString $ pprint name)) retTy [] [] argTys where (retTy, argTys) = impFunTyToLLVMTy impFunTy impFunTyToLLVMTy :: IFunType -> LLVMFunType diff --git a/src/lib/IncState.hs b/src/lib/IncState.hs new file mode 100644 index 000000000..e825bf1e3 --- /dev/null +++ b/src/lib/IncState.hs @@ -0,0 +1,141 @@ +-- Copyright 2023 Google LLC +-- +-- Use of this source code is governed by a BSD-style +-- license that can be found in the LICENSE file or at +-- https://developers.google.com/open-source/licenses/bsd + +{-# LANGUAGE UndecidableInstances #-} + +module IncState ( + IncState (..), MapEltUpdate (..), MapUpdate (..), + Overwrite (..), TailUpdate (..), Unchanging (..), Overwritable (..), + mapUpdateMapWithKey, MonoidState (..)) where + +import Data.Aeson (ToJSON, ToJSONKey) +import qualified Data.Map.Strict as M +import GHC.Generics + +-- === IncState === + +class Monoid d => IncState s d where + applyDiff :: s -> d -> s + +-- === Diff utils === + +data MapEltUpdate s d = + Create s + | Replace s -- TODO: should we merge Create/Replace? + | Update d + | Delete + deriving (Functor, Show, Generic) + +data MapUpdate k s d = MapUpdate { mapUpdates :: M.Map k (MapEltUpdate s d) } + deriving (Functor, Show, Generic) + +mapUpdateMapWithKey :: MapUpdate k s d -> (k -> s -> s') -> (k -> d -> d') -> MapUpdate k s' d' +mapUpdateMapWithKey (MapUpdate m) fs fd = + MapUpdate $ flip M.mapWithKey m \k v -> case v of + Create s -> Create $ fs k s + Replace s -> Replace $ fs k s + Update d -> Update $ fd k d + Delete -> Delete + +instance (IncState s d, Ord k) => Monoid (MapUpdate k s d) where + mempty = MapUpdate mempty + +instance (IncState s d, Ord k) => Semigroup (MapUpdate k s d) where + MapUpdate m1 <> MapUpdate m2 = MapUpdate $ + M.mapMaybe id (M.intersectionWith combineElts m1 m2) + <> M.difference m1 m2 + <> M.difference m2 m1 + where combineElts e1 e2 = case e1 of + Create s -> case e2 of + Create _ -> error "shouldn't be creating a node that already exists" + Replace s' -> Just $ Create s' + Update d -> Just $ Create (applyDiff s d) + Delete -> Nothing + Replace s -> case e2 of + Create _ -> error "shouldn't be creating a node that already exists" + Replace s' -> Just $ Replace s' + Update d -> Just $ Replace (applyDiff s d) + Delete -> Nothing + Update d -> case e2 of + Create _ -> error "shouldn't be creating a node that already exists" + Replace s -> Just $ Replace s + Update d' -> Just $ Update (d <> d') + Delete -> Just $ Delete + Delete -> case e2 of + Create s -> Just $ Replace s + Replace _ -> error "shouldn't be replacing a node that doesn't exist" + Update _ -> error "shouldn't be updating a node that doesn't exist" + Delete -> error "shouldn't be deleting a node that doesn't exist" + +instance (IncState s d, Ord k) => IncState (M.Map k s) (MapUpdate k s d) where + applyDiff m (MapUpdate updates) = + M.mapMaybe id (M.intersectionWith applyEltUpdate m updates) + <> M.difference m updates + <> M.mapMaybe applyEltCreation (M.difference updates m) + where applyEltUpdate s = \case + Create _ -> error "key already exists" + Replace s' -> Just s' + Update d -> Just $ applyDiff s d + Delete -> Nothing + applyEltCreation = \case + Create s -> Just s + Replace _ -> error "key doesn't exist yet" + Update _ -> error "key doesn't exist yet" + Delete -> error "key doesn't exist yet" + +data TailUpdate a = TailUpdate + { numDropped :: Int + , newTail :: [a] } + deriving (Show, Generic) + +instance Semigroup (TailUpdate a) where + TailUpdate n1 xs1 <> TailUpdate n2 xs2 = + let xs1Rem = length xs1 - n2 in + if xs1Rem >= 0 + then TailUpdate n1 (take xs1Rem xs1 <> xs2) -- n2 clobbered by xs1 + else TailUpdate (n1 - xs1Rem) xs2 -- xs1 clobbered by n2 + +instance Monoid (TailUpdate a) where + mempty = TailUpdate 0 [] + +instance IncState [a] (TailUpdate a) where + applyDiff xs (TailUpdate numDrop ys) = take (length xs - numDrop) xs <> ys + +-- Trivial diff that works for any type - just replace the old value with a completely new one. +data Overwrite a = NoChange | OverwriteWith a + deriving (Show, Eq, Generic, Functor, Foldable, Traversable) +newtype Overwritable a = Overwritable { fromOverwritable :: a } deriving (Show, Eq, Ord) + +instance Semigroup (Overwrite a) where + l <> r = case r of + OverwriteWith r' -> OverwriteWith r' + NoChange -> l + +instance Monoid (Overwrite a) where + mempty = NoChange + +instance IncState (Overwritable a) (Overwrite a) where + applyDiff s = \case + NoChange -> s + OverwriteWith s' -> Overwritable s' + +-- Case when the diff and the state are the same +newtype MonoidState a = MonoidState a + +instance Monoid a => IncState (MonoidState a) a where + applyDiff (MonoidState d) d' = MonoidState $ d <> d' + + +-- Trivial diff that works for any type - just replace the old value with a completely new one. +newtype Unchanging a = Unchanging { fromUnchanging :: a } deriving (Show, Eq, Ord) + +instance IncState (Unchanging a) () where + applyDiff s () = s + +instance ToJSON a => ToJSON (Overwrite a) +instance (ToJSON s, ToJSON d, ToJSONKey k) => ToJSON (MapUpdate k s d) +instance ToJSON a => ToJSON (TailUpdate a) +instance (ToJSON s, ToJSON d) => ToJSON (MapEltUpdate s d) diff --git a/src/lib/Inference.hs b/src/lib/Inference.hs index cd28d7c5e..50c936f2c 100644 --- a/src/lib/Inference.hs +++ b/src/lib/Inference.hs @@ -4,33 +4,27 @@ -- license that can be found in the LICENSE file or at -- https://developers.google.com/open-source/licenses/bsd -{-# LANGUAGE AllowAmbiguousTypes #-} {-# LANGUAGE UndecidableInstances #-} {-# OPTIONS_GHC -Wno-orphans #-} module Inference - ( inferTopUDecl, checkTopUType, inferTopUExpr - , trySynthTerm, generalizeDict, asTopBlock - , synthTopE, UDeclInferenceResult (..), asFFIFunType) where + ( inferTopUDecl, checkTopUType, inferTopUExpr , generalizeDict, asTopBlock + , UDeclInferenceResult (..), asFFIFunType) where import Prelude hiding ((.), id) import Control.Category import Control.Applicative import Control.Monad import Control.Monad.State.Strict -import Control.Monad.Writer.Strict hiding (Alt) import Control.Monad.Reader import Data.Either (partitionEithers) import Data.Foldable (toList, asum) import Data.Functor ((<&>)) import Data.List (sortOn) import Data.Maybe (fromJust, fromMaybe, catMaybes) -import Data.Text.Prettyprint.Doc (Pretty (..), (<+>), vcat, group, line, nest) import Data.Word import qualified Data.HashMap.Strict as HM import qualified Data.Map.Strict as M -import qualified Data.Set as S -import qualified Unsafe.Coerce as TrulyUnsafe import GHC.Generics (Generic (..)) import Builder @@ -39,27 +33,28 @@ import CheckType import Core import Err import IRVariants +import MonadUtil import MTL1 import Name -import SourceInfo import Subst +import PPrint import QueryType import Types.Core import Types.Imp import Types.Primitives import Types.Source +import Types.Top +import qualified Types.OpNames as P import Util hiding (group) -import PPrint (prettyBlock) -- === Top-level interface === -checkTopUType :: (Fallible1 m, EnvReader m) => UType n -> m n (CType n) -checkTopUType ty = liftInfererM $ solveLocal $ withApplyDefaults $ checkUType ty +checkTopUType :: (Fallible1 m, TopLogger m, EnvReader m) => UType n -> m n (CType n) +checkTopUType ty = liftInfererM $ checkUType ty {-# SCC checkTopUType #-} -inferTopUExpr :: (Fallible1 m, EnvReader m) => UExpr n -> m n (TopBlock CoreIR n) -inferTopUExpr e = asTopBlock =<< liftInfererM do - solveLocal $ buildBlockInf $ withApplyDefaults $ inferSigma noHint e +inferTopUExpr :: (Fallible1 m, TopLogger m, EnvReader m) => UExpr n -> m n (TopBlock CoreIR n) +inferTopUExpr e = fst <$> (asTopBlock =<< liftInfererM (buildBlock $ bottomUp e)) {-# SCC inferTopUExpr #-} data UDeclInferenceResult e n = @@ -67,27 +62,24 @@ data UDeclInferenceResult e n = | UDeclResultBindName LetAnn (TopBlock CoreIR n) (Abs (UBinder (AtomNameC CoreIR)) e n) | UDeclResultBindPattern NameHint (TopBlock CoreIR n) (ReconAbs CoreIR e n) -inferTopUDecl :: (Mut n, Fallible1 m, TopBuilder m, SinkableE e, HoistableE e, RenameE e) +type TopLogger (m::MonadKind1) = forall n. Logger Outputs (m n) + +inferTopUDecl :: (Mut n, Fallible1 m, TopBuilder m, HasNamesE e, TopLogger m) => UTopDecl n l -> e l -> m n (UDeclInferenceResult e n) inferTopUDecl (UStructDecl tc def) result = do tc' <- emitBinding (getNameHint tc) $ TyConBinding Nothing (DotMethods mempty) - def' <- liftInfererM $ solveLocal do - extendRenamer (tc@>sink tc') $ inferStructDef def - def'' <- synthTyConDef def' - updateTopEnv $ UpdateTyConDef tc' def'' - UStructDef _ (_, paramBs) _ methods <- return def + def' <- liftInfererM $ extendRenamer (tc@>tc') $ inferStructDef def + updateTopEnv $ UpdateTyConDef tc' def' + UStructDef _ paramBs _ methods <- return def forM_ methods \(letAnn, methodName, methodDef) -> do - method <- liftInfererM $ solveLocal $ - extendRenamer (tc@>sink tc') $ - inferDotMethod (sink tc') (Abs paramBs methodDef) - methodSynth <- synthTopE (Lam method) - method' <- emitTopLet (getNameHint methodName) letAnn (Atom methodSynth) + method <- liftInfererM $ extendRenamer (tc@>tc') $ + inferDotMethod tc' (Abs paramBs methodDef) + method' <- emitTopLet (getNameHint methodName) letAnn (Atom $ toAtom $ Lam method) updateTopEnv $ UpdateFieldDef tc' methodName (atomVarName method') UDeclResultDone <$> applyRename (tc @> tc') result inferTopUDecl (UDataDefDecl def tc dcs) result = do - tcDef <- liftInfererM $ solveLocal $ inferTyConDef def - tcDef'@(TyConDef _ _ _ (ADTCons dataCons)) <- synthTyConDef tcDef - tc' <- emitBinding (getNameHint tcDef') $ TyConBinding (Just tcDef') (DotMethods mempty) + tcDef@(TyConDef _ _ _ (ADTCons dataCons)) <- liftInfererM $ inferTyConDef def + tc' <- emitBinding (getNameHint tcDef) $ TyConBinding (Just tcDef) (DotMethods mempty) dcs' <- forM (enumerate dataCons) \(i, dcDef) -> emitBinding (getNameHint dcDef) $ DataConBinding tc' i let subst = tc @> tc' <.> dcs @@> dcs' @@ -95,994 +87,580 @@ inferTopUDecl (UDataDefDecl def tc dcs) result = do inferTopUDecl (UInterface paramBs methodTys className methodNames) result = do let sn = getSourceName className let methodSourceNames = nestToList getSourceName methodNames - classDef <- liftInfererM $ solveLocal $ inferClassDef sn methodSourceNames paramBs methodTys + classDef <- liftInfererM $ inferClassDef sn methodSourceNames paramBs methodTys className' <- emitBinding (getNameHint sn) $ ClassBinding classDef - methodNames' <- - forM (enumerate methodSourceNames) \(i, prettyName) -> do - emitBinding (getNameHint prettyName) $ MethodBinding className' i + methodNames' <- forM (enumerate methodSourceNames) \(i, prettyName) -> do + emitBinding (getNameHint prettyName) $ MethodBinding className' i let subst = className @> className' <.> methodNames @@> methodNames' UDeclResultDone <$> applyRename subst result -inferTopUDecl (UInstance className instanceBs params methods maybeName expl) result = do +inferTopUDecl (UInstance className bs params methods maybeName expl) result = do let (InternalName _ _ className') = className - ab <- liftInfererM $ solveLocal do - withRoleUBinders instanceBs do - ClassDef _ _ _ roleExpls paramBinders _ _ <- lookupClassDef (sink className') - let expls = snd <$> roleExpls - params' <- checkInstanceParams expls paramBinders params - className'' <- sinkM className' - body <- checkInstanceBody className'' params' methods - return (ListE params' `PairE` body) - Abs bs' (ListE params' `PairE` body) <- return ab - let (roleExpls, bs'') = unzipAttrs bs' - let def = InstanceDef className' roleExpls bs'' params' body + def <- liftInfererM $ withRoleUBinders bs \(ZipB roleExpls bs') -> do + ClassDef _ _ _ _ _ paramBinders _ _ <- lookupClassDef (sink className') + params' <- checkInstanceParams paramBinders params + body <- checkInstanceBody (sink className') params' methods + return $ InstanceDef className' roleExpls bs' params' body UDeclResultDone <$> case maybeName of RightB UnitB -> do - void $ synthInstanceDefAndAddSynthCandidate def + instanceName <- emitInstanceDef def + ClassDef _ builtinName _ _ _ _ _ _ <- lookupClassDef className' + addInstanceSynthCandidate className' builtinName instanceName return result JustB instanceName' -> do - def' <- synthInstanceDef def - instanceName <- emitInstanceDef def' + instanceName <- emitInstanceDef def lam <- instanceFun instanceName expl instanceAtomName <- emitTopLet (getNameHint instanceName') PlainLet $ Atom lam applyRename (instanceName' @> atomVarName instanceAtomName) result _ -> error "impossible" -inferTopUDecl (ULocalDecl (WithSrcB src decl)) result = addSrcContext src case decl of +inferTopUDecl (ULocalDecl (WithSrcB _ decl)) result = case decl of UPass -> return $ UDeclResultDone result UExprDecl _ -> error "Shouldn't have this at the top level (should have become a command instead)" ULet letAnn p tyAnn rhs -> case p of WithSrcB _ (UPatBinder b) -> do - block <- liftInfererM $ solveLocal $ buildBlockInf do - checkMaybeAnnExpr (getNameHint b) tyAnn rhs <* applyDefaults - topBlock <- asTopBlock block - return $ UDeclResultBindName letAnn topBlock (Abs b result) + block <- liftInfererM $ buildBlock do + checkMaybeAnnExpr tyAnn rhs + (topBlock, resultTy) <- asTopBlock block + let letAnn' = considerInlineAnn letAnn resultTy + return $ UDeclResultBindName letAnn' topBlock (Abs b result) _ -> do - PairE block recon <- liftInfererM $ solveLocal $ buildBlockInfWithRecon do - val <- checkMaybeAnnExpr (getNameHint p) tyAnn rhs - v <- emitHinted (getNameHint p) $ Atom val + PairE block recon <- liftInfererM $ buildBlockInfWithRecon do + val <- checkMaybeAnnExpr tyAnn rhs + v <- emitDecl (getNameHint p) PlainLet $ Atom val bindLetPat p v do - applyDefaults renameM result - topBlock <- asTopBlock block + (topBlock, _) <- asTopBlock block return $ UDeclResultBindPattern (getNameHint p) topBlock recon -inferTopUDecl (UEffectDecl _ _ _) _ = error "not implemented" -inferTopUDecl (UHandlerDecl _ _ _ _ _ _ _) _ = error "not implemented" {-# SCC inferTopUDecl #-} -asTopBlock :: EnvReader m => CBlock n -> m n (TopBlock CoreIR n) +asTopBlock :: EnvReader m => CExpr n -> m n (TopBlock CoreIR n, CType n) asTopBlock block = do - effTy <- blockEffTy block - return $ TopLam False (PiType Empty effTy) (LamExpr Empty block) + let effs = getEffects block + let ty = getType block + return (TopLam False (PiType Empty (EffTy effs ty)) (LamExpr Empty block), ty) getInstanceType :: EnvReader m => InstanceDef n -> m n (CorePiType n) getInstanceType (InstanceDef className roleExpls bs params _) = liftEnvReaderM do refreshAbs (Abs bs (ListE params)) \bs' (ListE params') -> do className' <- sinkM className - ClassDef classSourceName _ _ _ _ _ _ <- lookupClassDef className' - let dTy = DictTy $ DictType classSourceName className' params' + dTy <- toType <$> dictType className' params' return $ CorePiType ImplicitApp (snd <$> roleExpls) bs' $ EffTy Pure dTy --- === Inferer interface === - -class ( MonadFail1 m, Fallible1 m, Catchable1 m, CtxReader1 m, Builder CoreIR m ) - => InfBuilder (m::MonadKind1) where - - -- XXX: we should almost always used the zonking `buildDeclsInf` , - -- except where it's not possible because the result isn't atom-substitutable, - -- such as the source map at the top level. - buildDeclsInfUnzonked - :: (SinkableE e, HoistableE e, RenameE e) - => EmitsInf n - => (forall l. (EmitsBoth l, DExt n l) => m l (e l)) - -> m n (Abs (Nest CDecl) e n) - - buildAbsInf - :: (SinkableE e, HoistableE e, RenameE e, SubstE AtomSubstVal e) - => EmitsInf n - => NameHint -> Explicitness -> CType n - -> (forall l. (EmitsInf l, DExt n l) => CAtomVar l -> m l (e l)) - -> m n (Abs CBinder e n) - -buildAbsInfWithExpl - :: (InfBuilder m, SinkableE e, HoistableE e, RenameE e, SubstE AtomSubstVal e) - => EmitsInf n - => NameHint -> Explicitness -> CType n - -> (forall l. (EmitsInf l, DExt n l) => CAtomVar l -> m l (e l)) - -> m n (Abs (WithExpl CBinder) e n) -buildAbsInfWithExpl hint expl ty cont = do - Abs b e <- buildAbsInf hint expl ty cont - return $ Abs (WithAttrB expl b) e - -buildNaryAbsInfWithExpl - :: (Inferer m, SinkableE e, HoistableE e, RenameE e, SubstE AtomSubstVal e, Inferer m) - => EmitsInf n - => [Explicitness] -> EmptyAbs (Nest CBinder) n - -> (forall l. (EmitsInf l, DExt n l) => [CAtomVar l] -> m i l (e l)) - -> m i n (Abs (Nest (WithExpl CBinder)) e n) -buildNaryAbsInfWithExpl expls bs cont = do - Abs bs' e <- buildNaryAbsInf expls bs cont - return $ Abs (zipAttrs expls bs') e - -buildNaryAbsInf - :: (SinkableE e, HoistableE e, RenameE e, SubstE AtomSubstVal e, Inferer m) - => EmitsInf n - => [Explicitness] -> EmptyAbs (Nest CBinder) n - -> (forall l. (EmitsInf l, DExt n l) => [CAtomVar l] -> m i l (e l)) - -> m i n (Abs (Nest CBinder) e n) -buildNaryAbsInf [] (Abs Empty UnitE) cont = getDistinct >>= \Distinct -> Abs Empty <$> cont [] -buildNaryAbsInf (expl:expls) (Abs (Nest (b:>ty) bs) UnitE) cont = - prependAbs <$> buildAbsInf (getNameHint b) expl ty \v -> do - bs' <- applyRename (b@>atomVarName v) (Abs bs UnitE) - buildNaryAbsInf expls bs' \vs -> cont (sink v:vs) -buildNaryAbsInf _ _ _ = error "zip error" - -buildDeclsInf - :: (SubstE AtomSubstVal e, RenameE e, Solver m, InfBuilder m) - => (SinkableE e, HoistableE e) - => EmitsInf n - => (forall l. (EmitsBoth l, DExt n l) => m l (e l)) - -> m n (Abs (Nest CDecl) e n) -buildDeclsInf cont = buildDeclsInfUnzonked $ cont >>= zonk - -type InfBuilder2 (m::MonadKind2) = forall i. InfBuilder (m i) - -class (SubstReader Name m, InfBuilder2 m, Solver2 m) - => Inferer (m::MonadKind2) where - liftSolverMInf :: EmitsInf o => SolverM o a -> m i o a - addDefault :: CAtomName o -> DefaultType -> m i o () - getDefaults :: m i o (Defaults o) - -applyDefaults :: EmitsInf o => InfererM i o () -applyDefaults = do - defaults <- getDefaults - applyDefault (intDefaults defaults) (BaseTy $ Scalar Int32Type) - applyDefault (natDefaults defaults) NatTy - where - applyDefault ds ty = - forM_ (nameSetToList ds) \v -> do - v' <- toAtomVar v - tryConstrainEq (Var v') (Type ty) - -withApplyDefaults :: EmitsInf o => InfererM i o a -> InfererM i o a -withApplyDefaults cont = cont <* applyDefaults -{-# INLINE withApplyDefaults #-} - --- === Concrete Inferer monad === - -data InfOutMap (n::S) = - InfOutMap - (Env n) - (SolverSubst n) - (Defaults n) - -- the subset of the names in the bindings whose definitions may contain - -- inference vars (this is so we can avoid zonking everything in scope when - -- we zonk bindings) - (UnsolvedEnv n) - -- allowed effects - (EffectRow CoreIR n) - -data DefaultType = IntDefault | NatDefault - -data Defaults (n::S) = Defaults - { intDefaults :: NameSet n -- Set of names that should be defaulted to Int32 - , natDefaults :: NameSet n } -- Set of names that should be defaulted to Nat32 - -instance Semigroup (Defaults n) where - Defaults d1 d2 <> Defaults d1' d2' = Defaults (d1 <> d1') (d2 <> d2') - -instance Monoid (Defaults n) where - mempty = Defaults mempty mempty - -instance SinkableE Defaults where - sinkingProofE _ _ = todoSinkableProof -instance HoistableE Defaults where - freeVarsE (Defaults d1 d2) = d1 <> d2 -instance RenameE Defaults where - renameE env (Defaults d1 d2) = Defaults (substDefaultSet d1) (substDefaultSet d2) - where - substDefaultSet d = freeVarsE $ renameE env $ ListE $ nameSetToList @(AtomNameC CoreIR) d +-- === Inferer monad === -instance Pretty (Defaults n) where - pretty (Defaults ints nats) = - attach "Names defaulting to Int32" (nameSetToList @(AtomNameC CoreIR) ints) - <+> attach "Names defaulting to Nat32" (nameSetToList @(AtomNameC CoreIR) nats) - where - attach _ [] = mempty - attach s l = s <+> pretty l +newtype SolverSubst n = SolverSubst { fromSolverSubst :: M.Map (CAtomName n) (CAtom n) } -zonkDefaults :: SolverSubst n -> Defaults n -> Defaults n -zonkDefaults s (Defaults d1 d2) = - Defaults (zonkDefaultSet d1) (zonkDefaultSet d2) - where - zonkDefaultSet d = flip foldMap (nameSetToList @(AtomNameC CoreIR) d) \v -> - case lookupSolverSubst s v of - Rename v' -> freeVarsE v' - SubstVal (Var v') -> freeVarsE v' - _ -> mempty - -data InfOutFrag (n::S) (l::S) = InfOutFrag (InfEmissions n l) (Defaults l) (Constraints l) - -instance Pretty (InfOutFrag n l) where - pretty (InfOutFrag emissions defaults solverSubst) = - vcat [ "Pending emissions:" <+> pretty (unRNest emissions) - , "Defaults:" <+> pretty defaults - , "Solver substitution:" <+> pretty solverSubst - ] - -type InfEmission = EitherE (DeclBinding CoreIR) SolverBinding -type InfEmissions = RNest (BinderP (AtomNameC CoreIR) InfEmission) - -instance GenericB InfOutFrag where - type RepB InfOutFrag = PairB InfEmissions (LiftB (PairE Defaults Constraints)) - fromB (InfOutFrag emissions defaults solverSubst) = - PairB emissions (LiftB (PairE defaults solverSubst)) - toB (PairB emissions (LiftB (PairE defaults solverSubst))) = - InfOutFrag emissions defaults solverSubst - -instance ProvesExt InfOutFrag -instance RenameB InfOutFrag -instance BindsNames InfOutFrag -instance SinkableB InfOutFrag -instance HoistableB InfOutFrag - -instance OutFrag InfOutFrag where - emptyOutFrag = InfOutFrag REmpty mempty mempty - catOutFrags (InfOutFrag em ds ss) (InfOutFrag em' ds' ss') = - withExtEvidence em' $ - InfOutFrag (em >>> em') (sink ds <> ds') (sink ss <> ss') - -instance HasScope InfOutMap where - toScope (InfOutMap bindings _ _ _ _) = toScope bindings - -instance OutMap InfOutMap where - emptyOutMap = InfOutMap emptyOutMap emptySolverSubst mempty mempty Pure - -instance ExtOutMap InfOutMap EnvFrag where - extendOutMap (InfOutMap bindings ss dd oldUn effs) frag = - withExtEvidence frag do - let newUn = UnsolvedEnv $ getAtomNames frag - let newEnv = bindings `extendOutMap` frag - -- As an optimization, only do the zonking for the new stuff. - let (zonkedUn, zonkedEnv) = zonkUnsolvedEnv (sink ss) newUn newEnv - InfOutMap zonkedEnv (sink ss) (sink dd) (sink oldUn <> zonkedUn) (sink effs) - -newtype UnsolvedEnv (n::S) = - UnsolvedEnv { fromUnsolvedEnv :: S.Set (CAtomName n) } - deriving (Semigroup, Monoid) - -instance SinkableE UnsolvedEnv where - sinkingProofE = todoSinkableProof +emptySolverSubst :: SolverSubst n +emptySolverSubst = SolverSubst mempty -getAtomNames :: Distinct l => EnvFrag n l -> S.Set (CAtomName l) -getAtomNames frag = S.fromList $ nameSetToList $ toNameSet $ toScopeFrag frag - --- TODO: zonk the allowed effects and synth candidates in the bindings too --- TODO: the reason we need this is that `getType` uses the bindings to obtain --- type information, and we need this information when we emit decls. For --- example, if we emit `f x` and we don't know that `f` has a type of the form --- `a -> b` then `getType` will crash. But we control the inference-specific --- implementation of `emitDecl`, so maybe we could instead do something like --- emit a fresh inference variable in the case thea `getType` fails. --- XXX: It might be tempting to add a check for empty solver substs here, --- but please don't do that! We use this function to filter overestimates of --- UnsolvedEnv, and for performance reasons we should do that even when the --- SolverSubst is empty. -zonkUnsolvedEnv :: Distinct n => SolverSubst n -> UnsolvedEnv n -> Env n - -> (UnsolvedEnv n, Env n) -zonkUnsolvedEnv ss unsolved env = - flip runState env $ execWriterT do - forM_ (S.toList $ fromUnsolvedEnv unsolved) \v -> do - flip lookupEnvPure v . topEnv <$> get >>= \case - AtomNameBinding rhs -> do - let rhs' = zonkAtomBindingWithOutMap (InfOutMap env ss mempty mempty Pure) rhs - modify \e -> e {topEnv = updateEnv v (AtomNameBinding rhs') (topEnv e)} - let rhsHasInfVars = runEnvReaderM env $ hasInferenceVars rhs' - when rhsHasInfVars $ tell $ UnsolvedEnv $ S.singleton v - --- TODO: we need this shim because top level emissions can't implement `SubstE --- AtomSubstVal` so GHC doesn't know how to zonk them. If we split up top-level --- emissions from local ones in the name color system then we won't have this --- problem. -zonkAtomBindingWithOutMap - :: Distinct n => InfOutMap n -> AtomBinding CoreIR n -> AtomBinding CoreIR n -zonkAtomBindingWithOutMap outMap = \case - LetBound e -> LetBound $ zonkWithOutMap outMap e - MiscBound e -> MiscBound $ zonkWithOutMap outMap e - SolverBound e -> SolverBound $ zonkWithOutMap outMap e - NoinlineFun t e -> NoinlineFun (zonkWithOutMap outMap t) (zonkWithOutMap outMap e) - FFIFunBound x y -> FFIFunBound (zonkWithOutMap outMap x) (zonkWithOutMap outMap y) - --- TODO: Wouldn't it be faster to carry the set of inference-emitted names in the out map? -hasInferenceVars :: (EnvReader m, HoistableE e) => e n -> m n Bool -hasInferenceVars e = liftEnvReaderM $ anyInferenceVars $ freeAtomVarsList e -{-# INLINE hasInferenceVars #-} +data InfState (n::S) = InfState + { givens :: Givens n + , infEffects :: EffectRow CoreIR n } -anyInferenceVars :: [CAtomName n] -> EnvReaderM n Bool -anyInferenceVars = \case - [] -> return False - (v:vs) -> isInferenceVar v >>= \case - True -> return True - False -> anyInferenceVars vs +newtype InfererM (i::S) (o::S) (a:: *) = InfererM + { runInfererM' :: SubstReaderT Name (ReaderT1 InfState (BuilderT CoreIR (ExceptT (State TypeInfo)))) i o a } + deriving (Functor, Applicative, Monad, MonadFail, Alternative, Builder CoreIR, + EnvExtender, ScopableBuilder CoreIR, + ScopeReader, EnvReader, Fallible, Catchable, SubstReader Name) + +type InfererCPSB b i o a = (forall o'. DExt o o' => b o o' -> InfererM i o' a) -> InfererM i o a +type InfererCPSB2 b i i' o a = (forall o'. DExt o o' => b o o' -> InfererM i' o' a) -> InfererM i o a + +liftInfererM :: (EnvReader m, TopLogger m, Fallible1 m) => InfererM n n a -> m n a +liftInfererM cont = do + (ansExcept, typeInfo) <- liftInfererMPure cont + emitLog $ Outputs [SourceInfo $ SITypeInfo typeInfo] + liftExcept ansExcept +{-# INLINE liftInfererM #-} -isInferenceVar :: EnvReader m => CAtomName n -> m n Bool -isInferenceVar v = lookupEnv v >>= \case - AtomNameBinding (SolverBound _) -> return True - _ -> return False +liftInfererMPure :: EnvReader m => InfererM n n a -> m n (Except a, TypeInfo) +liftInfererMPure cont = do + ansM <- liftBuilderT $ runReaderT1 emptyInfState $ runSubstReaderT idSubst $ runInfererM' cont + return $ runState (runExceptT ansM) mempty + where + emptyInfState :: InfState n + emptyInfState = InfState (Givens HM.empty) Pure +{-# INLINE liftInfererMPure #-} -instance ExtOutMap InfOutMap InfOutFrag where - extendOutMap m (InfOutFrag em ds' cs) = do - let InfOutMap env ss ds us effs = m `extendOutMap` toEnvFrag em - let ds'' = sink ds <> ds' - let (env', us', ss') = extendOutMapWithConstraints env us ss cs - InfOutMap env' ss' ds'' us' effs - -extendOutMapWithConstraints - :: Distinct n => Env n -> UnsolvedEnv n -> SolverSubst n -> Constraints n - -> (Env n, UnsolvedEnv n, SolverSubst n) -extendOutMapWithConstraints env us ss (Constraints allCs) = case tryUnsnoc allCs of - Nothing -> (env, us, ss) - Just (cs, (v, x)) -> do - let (env', us', SolverSubst ss') = extendOutMapWithConstraints env us ss (Constraints cs) - let s = M.singleton v x - let (us'', env'') = zonkUnsolvedEnv (SolverSubst s) us' env' - let ss'' = fmap (applySolverSubstE env'' (SolverSubst s)) ss' - let ss''' = SolverSubst $ ss'' <> s - (env'', us'', ss''') +-- === Solver monad === -newtype InfererM (i::S) (o::S) (a:: *) = InfererM - { runInfererM' :: SubstReaderT Name (InplaceT InfOutMap InfOutFrag FallibleM) i o a } - deriving (Functor, Applicative, Monad, MonadFail, - ScopeReader, Fallible, Catchable, CtxReader, SubstReader Name) +type Solution = PairE CAtomName CAtom +newtype SolverDiff (n::S) = SolverDiff (RListE Solution n) + deriving (MonoidE, SinkableE, HoistableE, RenameE) +type SolverM i o a = DiffStateT1 SolverSubst SolverDiff (InfererM i) o a -liftInfererMSubst :: (Fallible2 m, SubstReader Name m, EnvReader2 m) - => InfererM i o a -> m i o a -liftInfererMSubst cont = do - env <- unsafeGetEnv - subst <- getSubst - Distinct <- getDistinct - (InfOutFrag REmpty _ _, result) <- - liftExcept $ runFallibleM $ runInplaceT (initInfOutMap env) $ - runSubstReaderT subst $ runInfererM' $ cont - return result - -liftInfererM :: (EnvReader m, Fallible1 m) - => InfererM n n a -> m n a -liftInfererM cont = runSubstReaderT idSubst $ liftInfererMSubst $ cont -{-# INLINE liftInfererM #-} +type Zonkable e = (HasNamesE e, SubstE AtomSubstVal e) -runLocalInfererM - :: SinkableE e - => (forall l. (EmitsInf l, DExt n l) => InfererM i l (e l)) - -> InfererM i n (Abs InfOutFrag e n) -runLocalInfererM cont = InfererM $ SubstReaderT $ ReaderT \env -> do - locallyMutableInplaceT (do - Distinct <- getDistinct - EmitsInf <- fabricateEmitsInfEvidenceM - runSubstReaderT (sink env) $ runInfererM' cont) - (\d e -> return $ Abs d e) -{-# INLINE runLocalInfererM #-} - -initInfOutMap :: Env n -> InfOutMap n -initInfOutMap bindings = - InfOutMap bindings emptySolverSubst mempty (UnsolvedEnv mempty) Pure - -newtype InfDeclEmission (n::S) (l::S) = InfDeclEmission (BinderP (AtomNameC CoreIR) InfEmission n l) -instance ExtOutMap InfOutMap InfDeclEmission where - extendOutMap env (InfDeclEmission d) = env `extendOutMap` toEnvFrag d - {-# INLINE extendOutMap #-} -instance ExtOutFrag InfOutFrag InfDeclEmission where - extendOutFrag (InfOutFrag ems ds ss) (InfDeclEmission em) = - withSubscopeDistinct em $ InfOutFrag (RNest ems em) (sink ds) (sink ss) - {-# INLINE extendOutFrag #-} - -emitInfererM :: Mut o => NameHint -> InfEmission o -> InfererM i o (CAtomVar o) -emitInfererM hint emission = do - v <- InfererM $ SubstReaderT $ lift $ freshExtendSubInplaceT hint \b -> - (InfDeclEmission (b :> emission), binderName b) - return $ AtomVar v $ getType emission -{-# INLINE emitInfererM #-} - -instance Solver (InfererM i) where - extendSolverSubst v ty = do - InfererM $ SubstReaderT $ lift $ - void $ extendTrivialInplaceT $ - InfOutFrag REmpty mempty (singleConstraint v ty) - {-# INLINE extendSolverSubst #-} - - zonk e = InfererM $ SubstReaderT $ lift do - Distinct <- getDistinct - solverOutMap <- getOutMapInplaceT - return $ zonkWithOutMap solverOutMap e - {-# INLINE zonk #-} - - emitSolver binding = emitInfererM (getNameHint @String "?") $ RightE binding - {-# INLINE emitSolver #-} - - solveLocal cont = do - Abs (InfOutFrag unsolvedInfVars _ _) result <- dceInfFrag =<< runLocalInfererM cont - case unRNest unsolvedInfVars of - Empty -> return result - Nest (b:>RightE (InfVarBound ty (ctx, desc))) _ -> addSrcContext ctx $ - throw TypeErr $ formatAmbiguousVarErr (binderName b) ty desc - _ -> error "shouldn't be possible" - -formatAmbiguousVarErr :: CAtomName n -> CType n' -> InfVarDesc -> String -formatAmbiguousVarErr infVar ty = \case - AnnotationInfVar v -> - "Couldn't infer type of unannotated binder " <> v - ImplicitArgInfVar (f, argName) -> - "Couldn't infer implicit argument " <> argName <> " of " <> f - TypeInstantiationInfVar t -> - "Couldn't infer instantiation of type " <> t - MiscInfVar -> - "Ambiguous type variable: " ++ pprint infVar ++ ": " ++ pprint ty - -instance InfBuilder (InfererM i) where - buildDeclsInfUnzonked cont = do - InfererM $ SubstReaderT $ ReaderT \env -> do - Abs frag result <- locallyMutableInplaceT (do - Emits <- fabricateEmitsEvidenceM - EmitsInf <- fabricateEmitsInfEvidenceM - runSubstReaderT (sink env) $ runInfererM' cont) - (\d e -> return $ Abs d e) - extendInplaceT =<< hoistThroughDecls frag result - - buildAbsInf hint expl ty cont = do - ab <- InfererM $ SubstReaderT $ ReaderT \env -> do - extendInplaceT =<< withFreshBinder hint ty \bWithTy@(b:>_) -> do - ab <- locallyMutableInplaceT (do - v <- sinkM $ binderVar bWithTy - extendInplaceTLocal (extendSynthCandidatesInf expl $ atomVarName v) do - EmitsInf <- fabricateEmitsInfEvidenceM - -- zonking is needed so that dceInfFrag works properly - runSubstReaderT (sink env) (runInfererM' $ cont v >>= zonk)) - (\d e -> return $ Abs d e) - ab' <- dceInfFrag ab - refreshAbs ab' \infFrag result -> do - case exchangeBs $ PairB b infFrag of - HoistSuccess (PairB infFrag' b') -> do - return $ withSubscopeDistinct b' $ - Abs infFrag' $ Abs b' result - HoistFailure vs -> do - throw EscapedNameErr $ (pprint vs) - ++ "\nFailed to exchange binders in buildAbsInf" - ++ "\n" ++ pprint infFrag - Abs b e <- return ab - ty' <- zonk ty - return $ Abs (b:>ty') e - -dceInfFrag - :: (EnvReader m, EnvExtender m, Fallible1 m, RenameE e, HoistableE e) - => Abs InfOutFrag e n -> m n (Abs InfOutFrag e n) -dceInfFrag ab@(Abs frag@(InfOutFrag bs _ _) e) = - case bs of - REmpty -> return ab - _ -> hoistThroughDecls frag e >>= \case - Abs frag' (Abs Empty e') -> return $ Abs frag' e' - _ -> error "Shouldn't have any decls without `Emits` constraint" - -instance Inferer InfererM where - liftSolverMInf m = InfererM $ SubstReaderT $ lift $ - liftBetweenInplaceTs (liftExcept . liftM fromJust . runSearcherM) id liftSolverOutFrag $ - runSolverM' m - {-# INLINE liftSolverMInf #-} - - addDefault v defaultType = - InfererM $ SubstReaderT $ lift $ - extendTrivialInplaceT $ InfOutFrag REmpty defaults mempty - where - defaults = case defaultType of - IntDefault -> Defaults (freeVarsE v) mempty - NatDefault -> Defaults mempty (freeVarsE v) - - getDefaults = InfererM $ SubstReaderT $ lift do - InfOutMap _ _ defaults _ _ <- getOutMapInplaceT - return defaults - -instance Builder CoreIR (InfererM i) where - rawEmitDecl hint ann expr = do - -- This zonking, and the zonking of the bindings elsewhere, is only to - -- prevent `getType` from failing. But maybe we should just catch the - -- failure if it occurs and generate a fresh inference name for the type in - -- that case? - expr' <- zonk expr - emitInfererM hint $ LeftE $ DeclBinding ann expr' - {-# INLINE rawEmitDecl #-} - -getAllowedEffects :: InfererM i n (EffectRow CoreIR n) -getAllowedEffects = do - InfOutMap _ _ _ _ effs <- InfererM $ SubstReaderT $ lift $ getOutMapInplaceT - return effs - -withoutEffects :: InfererM i o a -> InfererM i o a -withoutEffects cont = withAllowedEffects Pure cont +liftSolverM :: SolverM i o a -> InfererM i o a +liftSolverM cont = fst <$> runDiffStateT1 emptySolverSubst cont -withAllowedEffects :: EffectRow CoreIR o -> InfererM i o a -> InfererM i o a -withAllowedEffects effs cont = do - InfererM $ SubstReaderT $ ReaderT \env -> do - extendInplaceTLocal (\(InfOutMap x y z w _) -> InfOutMap x y z w effs) do - runSubstReaderT env $ runInfererM' do - cont +solverFail :: SolverM i o a +solverFail = empty + +zonk :: Zonkable e => e n -> SolverM i n (e n) +zonk e = do + s <- getDiffState + applySolverSubst s e +{-# INLINE zonk #-} + +zonkStuck :: CStuck n -> SolverM i n (CAtom n) +zonkStuck stuck = do + solverSubst <- getDiffState + Distinct <- getDistinct + env <- unsafeGetEnv + let subst = newSubst (lookupSolverSubst solverSubst) + return $ substStuck (env, subst) stuck -type InferenceNameBinders = Nest (BinderP (AtomNameC CoreIR) SolverBinding) - --- When we finish building a block of decls we need to hoist the local solver --- information into the outer scope. If the local solver state mentions local --- variables which are about to go out of scope then we emit a "escaped scope" --- error. To avoid false positives, we clean up as much dead (i.e. solved) --- solver state as possible. -hoistThroughDecls - :: ( RenameE e, HoistableE e, Fallible1 m, ScopeReader m, EnvExtender m) - => InfOutFrag n l - -> e l - -> m n (Abs InfOutFrag (Abs (Nest CDecl) e) n) -hoistThroughDecls outFrag result = do +applySolverSubst :: (EnvReader m, Zonkable e) => SolverSubst n -> e n -> m n (e n) +applySolverSubst subst e = do + Distinct <- getDistinct env <- unsafeGetEnv - refreshAbs (Abs outFrag result) \outFrag' result' -> do - liftExcept $ hoistThroughDecls' env outFrag' result' -{-# INLINE hoistThroughDecls #-} - -hoistThroughDecls' - :: (HoistableE e, Distinct l) - => Env n - -> InfOutFrag n l - -> e l - -> Except (Abs InfOutFrag (Abs (Nest CDecl) e) n) -hoistThroughDecls' env (InfOutFrag emissions defaults constraints) result = do - withSubscopeDistinct emissions do - let subst = constraintsToSubst (env `extendOutMap` toEnvFrag emissions) constraints - HoistedSolverState infVars defaults' subst' decls result' <- - hoistInfStateRec env emissions emptyInferenceNameBindersFV - (zonkDefaults subst defaults) (UnhoistedSolverSubst emptyOutFrag subst) Empty result - let constraints' = substToConstraints subst' - let hoistedInfFrag = InfOutFrag (infNamesToEmissions infVars) defaults' constraints' - return $ Abs hoistedInfFrag $ Abs decls result' - -constraintsToSubst :: Distinct n => Env n -> Constraints n -> SolverSubst n -constraintsToSubst env (Constraints csTop) = case tryUnsnoc csTop of - Nothing -> emptySolverSubst - Just (cs, (v, x)) -> do - let SolverSubst m = constraintsToSubst env (Constraints cs) - let s = M.singleton v x - SolverSubst $ fmap (applySolverSubstE env (SolverSubst s)) m <> s - -substToConstraints :: SolverSubst n -> Constraints n -substToConstraints (SolverSubst m) = Constraints $ toSnocList $ M.toList m - -data HoistedSolverState e n where - HoistedSolverState - :: InferenceNameBinders n l1 - -> Defaults l1 - -> SolverSubst l1 - -> Nest CDecl l1 l2 - -> e l2 - -> HoistedSolverState e n - --- XXX: Be careful how you construct DelayedSolveNests! When the substitution is --- applied, the pieces are concatenated through regular map concatenation, not --- through recursive substitutions as in catSolverSubsts! This is safe to do when --- the individual SolverSubsts come from a projection of a larger SolverSubst, --- which is how we use them in `hoistInfStateRec`. -type DelayedSolveNest (b::B) (n::S) (l::S) = Nest (EitherB b (LiftB SolverSubst)) n l - -resolveDelayedSolve :: Distinct l => Env n -> SolverSubst n -> DelayedSolveNest CDecl n l -> Nest CDecl n l -resolveDelayedSolve env subst = \case - Empty -> Empty - Nest (RightB (LiftB sfrag)) rest -> resolveDelayedSolve env (subst `unsafeCatSolverSubst` sfrag) rest - Nest (LeftB (Let b rhs) ) rest -> - withSubscopeDistinct rest $ withSubscopeDistinct b $ - Nest (Let b (applySolverSubstE env subst rhs)) $ - resolveDelayedSolve (env `extendOutMap` toEnvFrag (b:>rhs)) (sink subst) rest - where - unsafeCatSolverSubst :: SolverSubst n -> SolverSubst n -> SolverSubst n - unsafeCatSolverSubst (SolverSubst a) (SolverSubst b) = SolverSubst $ a <> b - -data InferenceNameBindersFV (n::S) (l::S) = InferenceNameBindersFV (NameSet n) (InferenceNameBinders n l) -instance BindsNames InferenceNameBindersFV where - toScopeFrag = toScopeFrag . dropInferenceNameBindersFV -instance BindsEnv InferenceNameBindersFV where - toEnvFrag = toEnvFrag . dropInferenceNameBindersFV -instance ProvesExt InferenceNameBindersFV where - toExtEvidence = toExtEvidence . dropInferenceNameBindersFV -instance HoistableB InferenceNameBindersFV where - freeVarsB (InferenceNameBindersFV fvs _) = fvs - -emptyInferenceNameBindersFV :: InferenceNameBindersFV n n -emptyInferenceNameBindersFV = InferenceNameBindersFV mempty Empty - -dropInferenceNameBindersFV :: InferenceNameBindersFV n l -> InferenceNameBinders n l -dropInferenceNameBindersFV (InferenceNameBindersFV _ bs) = bs - -prependNameBinder - :: BinderP (AtomNameC CoreIR) SolverBinding n q - -> InferenceNameBindersFV q l -> InferenceNameBindersFV n l -prependNameBinder b (InferenceNameBindersFV fvs bs) = - InferenceNameBindersFV (freeVarsB b <> hoistFilterNameSet b fvs) (Nest b bs) - --- XXX: Stashing Distinct here is a little naughty, since that's generally not allowed. --- Here it should be ok, because it's only used in hoistInfStateRec, which doesn't emit. -data UnhoistedSolverSubst (n::S) where - UnhoistedSolverSubst :: Distinct l => ScopeFrag n l -> SolverSubst l -> UnhoistedSolverSubst n - -delayedHoistSolverSubst :: BindsNames b => b n l -> UnhoistedSolverSubst l -> UnhoistedSolverSubst n -delayedHoistSolverSubst b (UnhoistedSolverSubst frag s) = UnhoistedSolverSubst (toScopeFrag b >>> frag) s - -hoistSolverSubst :: UnhoistedSolverSubst n -> HoistExcept (SolverSubst n) -hoistSolverSubst (UnhoistedSolverSubst frag s) = hoist frag s - --- TODO: Instead of delaying the solve, compute the most-nested scope once --- and then use it for all _eager_ substitutions while hoisting! Using a super-scope --- for substitution shouldn't be a problem! -hoistInfStateRec - :: forall n l l1 l2 e. (Distinct n, Distinct l2, HoistableE e) - => Env n -> InfEmissions n l - -> InferenceNameBindersFV l l1 -> Defaults l1 -> UnhoistedSolverSubst l1 - -> DelayedSolveNest CDecl l1 l2 -> e l2 - -> Except (HoistedSolverState e n) -hoistInfStateRec env emissions !infVars defaults !subst decls e = case emissions of - REmpty -> do - subst' <- liftHoistExcept' "Failed to hoist solver substitution in hoistInfStateRec" - $ hoistSolverSubst subst - let decls' = withSubscopeDistinct decls $ - resolveDelayedSolve (env `extendOutMap` toEnvFrag infVars) subst' decls - return $ HoistedSolverState (dropInferenceNameBindersFV infVars) defaults subst' decls' e - RNest rest (b :> infEmission) -> do - withSubscopeDistinct decls do - case infEmission of - RightE binding@(InfVarBound _ _) -> do - UnhoistedSolverSubst frag (SolverSubst substMap) <- return subst - let vHoist :: CAtomName l1 = withSubscopeDistinct infVars $ sink $ binderName b -- binder name at l1 - let vUnhoist = withExtEvidence frag $ sink vHoist -- binder name below frag - case M.lookup vUnhoist substMap of - -- Unsolved inference variables are just gathered as they are. - Nothing -> - hoistInfStateRec env rest (prependNameBinder (b:>binding) infVars) - defaults subst decls e - -- If a variable is solved, we eliminate it. - Just bSolutionUnhoisted -> do - bSolution <- - liftHoistExcept' "Failed to eliminate solved variable in hoistInfStateRec " - $ hoist frag bSolutionUnhoisted - case exchangeBs $ PairB b infVars of - -- This used to be accepted by the code at some point (and handled the same way - -- as the Nothing) branch above, but I don't understand why. We don't even seem - -- to be exercising it anyway, so throw a not implemented error for now. - HoistFailure _ -> throw NotImplementedErr "Unzonked unsolved variables" - HoistSuccess (PairB infVars' b') -> do - let defaults' = hoistDefaults b' defaults - let bZonkedDecls = Nest (RightB (LiftB $ SolverSubst $ M.singleton vHoist bSolution)) decls -#ifdef DEX_DEBUG - -- Hoist the subst eagerly, unlike the unsafe implementation. - hoistedSubst@(SolverSubst hoistMap) <- liftHoistExcept $ hoistSolverSubst subst - let subst' = withSubscopeDistinct b' $ UnhoistedSolverSubst (toScopeFrag b') $ - SolverSubst $ M.delete vHoist hoistMap - -- Zonk the decls with `v @> bSolution` to make sure hoisting will succeed. - -- This is quadratic, which is why we don't do this in the fast implementation! - let allEmissions = RNest rest (b :> infEmission) - let declsScope = withSubscopeDistinct infVars $ - (env `extendOutMap` toEnvFrag allEmissions) `extendOutMap` toEnvFrag infVars - let resolvedDecls = resolveDelayedSolve declsScope hoistedSubst bZonkedDecls - PairB resolvedDecls' b'' <- liftHoistExcept $ exchangeBs $ PairB b' resolvedDecls - let decls' = fmapNest LeftB resolvedDecls' - -- NB: We assume that e is hoistable above e! This has to be taken - -- care of by zonking the result before this function is entered. - e' <- liftHoistExcept $ hoist b'' e - withSubscopeDistinct b'' $ - hoistInfStateRec env rest infVars' defaults' subst' decls' e' -#else - -- SolverSubst should be recursively zonked, so any v that's a member - -- should never appear in an rhs. Hence, deleting the entry corresponding to - -- v should hoist the substitution above b'. - let subst' = unsafeCoerceE $ UnhoistedSolverSubst frag $ SolverSubst $ M.delete vUnhoist substMap - -- Applying the substitution `v @> bSolution` would eliminate `b` from decls, so this - -- is equivalent to hoisting above b'. This is of course not reflected in the type - -- system, which is why we use unsafe coercions. - let decls' = unsafeCoerceB bZonkedDecls - -- This is much more sketchy, but it reflects the e-hoistability assumption - -- that our safe implementation makes as well. Except here it's obviously unchecked. - let e' :: e UnsafeS = unsafeCoerceE e - Distinct <- return $ fabricateDistinctEvidence @UnsafeS - hoistInfStateRec env rest infVars' defaults' subst' decls' e' -#endif - RightE (SkolemBound _) -> do -#ifdef DEX_DEBUG - PairB infVars' b' <- liftHoistExcept' "Skolem leak?" $ exchangeBs $ PairB b infVars - defaults' <- liftHoistExcept' "Skolem leak?" $ hoist b' defaults - let subst' = delayedHoistSolverSubst b' subst - PairB decls' b'' <- liftHoistExcept' "Skolem leak?" $ exchangeBs $ PairB b' decls - e' <- liftHoistExcept' "Skolem leak?" $ hoist b'' e - withSubscopeDistinct b'' $ hoistInfStateRec env rest infVars' defaults' subst' decls' e' -#else - -- Skolem vars are only instantiated in unification, and we're very careful to - -- never let them leak into the types of inference vars emitted while unifying - -- and into the solver subst. - Distinct <- return $ fabricateDistinctEvidence @UnsafeS - hoistInfStateRec @n @UnsafeS @UnsafeS @UnsafeS - env - (unsafeCoerceB rest) (unsafeCoerceB infVars) - (unsafeCoerceE defaults) (unsafeCoerceE subst) - (unsafeCoerceB decls) (unsafeCoerceE e) -#endif - LeftE emission -> do - -- Move the binder below all unsolved inference vars. Failure to do so is - -- an inference error --- a variable cannot be solved once we exit the env - -- of all variables it mentions in its type. - -- TODO: Shouldn't this be an ambiguous type error? - PairB infVars' (b':>emission') <- - liftHoistExcept' "Failed to move binder below unsovled inference vars" - $ exchangeBs (PairB (b:>emission) infVars) - -- Now, those are real leakage errors. We never want to leak this var through a solution! - -- But since we delay hoisting, they will only be raised later. - let subst' = delayedHoistSolverSubst b' subst - let defaults' = hoistDefaults b' defaults - let decls' = Nest (LeftB (Let b' emission')) decls - hoistInfStateRec env rest infVars' defaults' subst' decls' e - -hoistDefaults :: BindsNames b => b n l -> Defaults l -> Defaults n -hoistDefaults b (Defaults d1 d2) = Defaults (hoistFilterNameSet b d1) - (hoistFilterNameSet b d2) - -infNamesToEmissions :: InferenceNameBinders n l -> InfEmissions n l -infNamesToEmissions = go REmpty + return $ fmapNames env (lookupSolverSubst subst) e +{-# INLINE applySolverSubst #-} + +withFreshBinderInf :: NameHint -> Explicitness -> CType o -> InfererCPSB CBinder i o a +withFreshBinderInf hint expl ty cont = + withFreshBinder hint ty \b -> do + givens <- case expl of + Inferred _ (Synth _) -> return [toAtom $ binderVar b] + _ -> return [] + extendGivens givens $ cont b +{-# INLINE withFreshBinderInf #-} + +withFreshBindersInf + :: (SinkableE e, RenameE e) + => [Explicitness] -> Abs (Nest CBinder) e o + -> (forall o'. DExt o o' => Nest CBinder o o' -> e o' -> InfererM i o' a) + -> InfererM i o a +withFreshBindersInf explsTop (Abs bsTop e) contTop = + runSubstReaderT idSubst $ go explsTop bsTop \bs' -> do + e' <- renameM e + liftSubstReaderT $ contTop bs' e' + where + go :: [Explicitness] -> Nest CBinder ii ii' + -> (forall o'. DExt o o' => Nest CBinder o o' -> SubstReaderT Name (InfererM i) ii' o' a) + -> SubstReaderT Name (InfererM i) ii o a + go [] Empty cont = withDistinct $ cont Empty + go (expl:expls) (Nest b bs) cont = do + ty <- renameM $ binderType b + SubstReaderT \s -> withFreshBinderInf (getNameHint b) expl ty \b' -> do + runSubstReaderT (sink s) $ extendSubst (b@>binderName b') do + go expls bs \bs' -> cont (Nest b' bs') + go _ _ _ = error "zip error" +{-# INLINE withFreshBindersInf #-} + +withInferenceVar + :: (Zonkable e, Emits o, ToBinding binding (AtomNameC CoreIR)) => NameHint -> binding o + -> (forall o'. (Emits o', DExt o o') => CAtomName o' -> SolverM i o' (e o', CAtom o')) + -> SolverM i o (e o) +withInferenceVar hint binding cont = diffStateT1 \s -> do + declsAndAns <- withFreshBinder hint binding \(b:>_) -> do + hardHoist b <$> buildScoped do + v <- sinkM $ binderName b + s' <- sinkM s + (PairE ans soln, diff) <- runDiffStateT1 s' do + toPairE <$> cont v + let subst = SolverSubst $ M.singleton v soln + ans' <- applySolverSubst subst ans + diff' <- applySolutionToDiff subst v diff + return $ PairE ans' diff' + fromPairE <$> emitDecls declsAndAns where - go :: InfEmissions n q -> InferenceNameBinders q l -> InfEmissions n l - go acc = \case - Empty -> acc - Nest (b:>binding) rest -> go (RNest acc (b:>RightE binding)) rest - -instance EnvReader (InfererM i) where - unsafeGetEnv = do - InfOutMap bindings _ _ _ _ <- InfererM $ SubstReaderT $ lift $ getOutMapInplaceT - return bindings - {-# INLINE unsafeGetEnv #-} - -instance EnvExtender (InfererM i) where - refreshAbs ab cont = InfererM $ SubstReaderT $ ReaderT \env -> do - refreshAbs ab \b e -> runSubstReaderT (sink env) $ runInfererM' $ cont b e - {-# INLINE refreshAbs #-} - --- === helpers for extending synthesis candidates === - --- TODO: we should pull synth candidates out of the Env and then we can treat it --- like an ordinary reader without all this ceremony. - -extendSynthCandidatesInf :: Explicitness -> CAtomName n -> InfOutMap n -> InfOutMap n -extendSynthCandidatesInf c v (InfOutMap env x y z w) = - InfOutMap (extendSynthCandidates c v env) x y z w -{-# INLINE extendSynthCandidatesInf #-} - -extendSynthCandidates :: Explicitness -> CAtomName n -> Env n -> Env n -extendSynthCandidates (Inferred _ (Synth _)) v (Env topEnv (ModuleEnv a b scs)) = - Env topEnv (ModuleEnv a b scs') - where scs' = scs <> SynthCandidates [v] mempty -extendSynthCandidates _ _ env = env -{-# INLINE extendSynthCandidates #-} - -extendSynthCandidatess :: Distinct n => [Explicitness] -> Nest CBinder n' n -> Env n -> Env n -extendSynthCandidatess (expl:expls) (Nest b bs) env = - extendSynthCandidatess expls bs env' - where env' = extendSynthCandidates expl (withExtEvidence bs $ sink $ binderName b) env -extendSynthCandidatess [] Empty env = env -extendSynthCandidatess _ _ _ = error "zip error" -{-# INLINE extendSynthCandidatess #-} + applySolutionToDiff :: SolverSubst n -> CAtomName n -> SolverDiff n -> InfererM i n (SolverDiff n) + applySolutionToDiff subst vSoln (SolverDiff (RListE (ReversedList cs))) = do + SolverDiff . RListE . ReversedList <$> forMFilter cs \(PairE v x) -> + case v == vSoln of + True -> return Nothing + False -> Just . PairE v <$> applySolverSubst subst x +{-# INLINE withInferenceVar #-} + +withFreshUnificationVar + :: (Zonkable e, Emits o) => SrcId -> InfVarDesc -> Kind CoreIR o + -> (forall o'. (Emits o', DExt o o') => CAtomVar o' -> SolverM i o' (e o')) + -> SolverM i o (e o) +withFreshUnificationVar sid desc k cont = do + withInferenceVar "_unif_" (InfVarBound k) \v -> do + ans <- toAtomVar v >>= cont + soln <- (M.lookup v <$> fromSolverSubst <$> getDiffState) >>= \case + Just soln -> return soln + Nothing -> throw sid $ AmbiguousInferenceVar (pprint v) (pprint k) desc + return (ans, soln) +{-# INLINE withFreshUnificationVar #-} + +withFreshUnificationVarNoEmits + :: (Zonkable e) => SrcId -> InfVarDesc -> Kind CoreIR o + -> (forall o'. (DExt o o') => CAtomVar o' -> SolverM i o' (e o')) + -> SolverM i o (e o) +withFreshUnificationVarNoEmits sid desc k cont = diffStateT1 \s -> do + Abs Empty resultAndDiff <- buildScoped do + liftM toPairE $ runDiffStateT1 (sink s) $ + withFreshUnificationVar sid desc (sink k) cont + return $ fromPairE resultAndDiff + +withFreshDictVar + :: (Zonkable e, Emits o) => CType o + -- This tells us how to synthesize the dict. The supplied CType won't contain inference vars. + -> (forall o'. ( DExt o o') => CType o' -> SolverM i o' (CAtom o')) + -> (forall o'. (Emits o', DExt o o') => CAtom o' -> SolverM i o' (e o')) + -> SolverM i o (e o) +withFreshDictVar dictTy synthIt cont = hasInferenceVars dictTy >>= \case + False -> withDistinct $ synthIt dictTy >>= cont + True -> withInferenceVar "_dict_" (DictBound dictTy) \v -> do + ans <- cont =<< (toAtom <$> toAtomVar v) + dictTy' <- zonk $ sink dictTy + dict <- synthIt dictTy' + return (ans, dict) +{-# INLINE withFreshDictVar #-} + +withFreshDictVarNoEmits + :: (Zonkable e) => CType o + -- This tells us how to synthesize the dict. The supplied CType won't contain inference vars. + -> (forall o'. (DExt o o') => CType o' -> SolverM i o' (CAtom o')) + -> (forall o'. (DExt o o') => CAtom o' -> SolverM i o' (e o')) + -> SolverM i o (e o) +withFreshDictVarNoEmits dictTy synthIt cont = diffStateT1 \s -> do + Abs Empty resultAndDiff <- buildScoped do + liftM toPairE $ runDiffStateT1 (sink s) $ + withFreshDictVar (sink dictTy) synthIt cont + return $ fromPairE resultAndDiff +{-# INLINE withFreshDictVarNoEmits #-} + +withDict + :: (Zonkable e, Emits o) => SrcId -> CType o + -> (forall o'. (Emits o', DExt o o') => CAtom o' -> SolverM i o' (e o')) + -> SolverM i o (e o) +withDict sid dictTy cont = withFreshDictVar dictTy + (\dictTy' -> lift11 $ trySynthTerm sid dictTy' Full) + cont +{-# INLINE withDict#-} + +addConstraint :: CAtomName o -> CAtom o -> SolverM i o () +addConstraint v ty = updateDiffStateM (SolverDiff $ RListE $ toSnocList [PairE v ty]) +{-# INLINE addConstraint #-} + +getInfState :: InfererM i o (InfState o) +getInfState = InfererM $ liftSubstReaderT ask +{-# INLINE getInfState #-} + +withInfState :: (InfState o -> InfState o) -> InfererM i o a -> InfererM i o a +withInfState f cont = InfererM $ local f (runInfererM' cont) +{-# INLINE withInfState #-} --- === actual inference pass === +withAllowedEffects :: EffectRow CoreIR o -> InfererM i o a -> InfererM i o a +withAllowedEffects effs cont = withInfState (\(InfState g _) -> InfState g effs) cont +{-# INLINE withAllowedEffects #-} -data RequiredTy (e::E) (n::S) = - Check (e n) - | Infer - deriving Show +emitTypeInfo :: SrcId -> String -> InfererM i o () +emitTypeInfo sid ty = do + InfererM $ liftSubstReaderT $ lift11 $ lift1 $ lift do + modify \(TypeInfo m) -> TypeInfo $ M.insert sid ty m -checkSigma :: EmitsBoth o - => NameHint -> UExpr i -> CType o -> InfererM i o (CAtom o) -checkSigma hint expr sTy = confuseGHC >>= \_ -> case sTy of - Pi piTy@(CorePiType _ expls _ _) -> do - if all (== Explicit) expls - then fallback - else case expr of - WithSrcE src (ULam lam) -> addSrcContext src $ Lam <$> checkULam lam piTy - _ -> Lam <$> buildLamInf piTy \args resultTy -> do - explicits <- return $ catMaybes $ args <&> \case - (Explicit, arg) -> Just arg - _ -> Nothing - expr' <- inferWithoutInstantiation expr >>= zonk - dropSubst $ checkOrInferApp expr' explicits [] (Check resultTy) - DepPairTy depPairTy -> case depPairTy of - DepPairType ImplicitDepPair (_ :> lhsTy) _ -> do - -- TODO: check for the case that we're given some of the implicit dependent pair args explicitly - lhsVal <- Var <$> freshInferenceName MiscInfVar lhsTy - -- TODO: make an InfVarDesc case for dep pair instantiation - rhsTy <- instantiate depPairTy [lhsVal] - rhsVal <- checkSigma noHint expr rhsTy - return $ DepPair lhsVal rhsVal depPairTy - _ -> fallback - _ -> fallback - where fallback = checkOrInferRho hint expr (Check sTy) - -inferSigma :: EmitsBoth o => NameHint -> UExpr i -> InfererM i o (CAtom o) -inferSigma hint (WithSrcE pos expr) = case expr of - ULam lam -> addSrcContext pos $ Lam <$> inferULam lam - _ -> inferRho hint (WithSrcE pos expr) - -checkRho :: EmitsBoth o => - NameHint -> UExpr i -> CType o -> InfererM i o (CAtom o) -checkRho hint expr ty = checkOrInferRho hint expr (Check ty) -{-# INLINE checkRho #-} - -inferRho :: EmitsBoth o => - NameHint -> UExpr i -> InfererM i o (CAtom o) -inferRho hint expr = checkOrInferRho hint expr Infer -{-# INLINE inferRho #-} - -getImplicitArg :: EmitsInf o => InferenceArgDesc -> InferenceMechanism -> CType o -> InfererM i o (CAtom o) -getImplicitArg desc inf argTy = case inf of - Unify -> Var <$> freshInferenceName (ImplicitArgInfVar desc) argTy - Synth reqMethodAccess -> do - ctx <- srcPosCtx <$> getErrCtx - return $ DictHole (AlwaysEqual ctx) argTy reqMethodAccess +withReducibleEmissions + :: (HasNamesE e, SubstE AtomSubstVal e, ToErr err) + => SrcId -> err + -> (forall o' . (Emits o', DExt o o') => InfererM i o' (e o')) + -> InfererM i o (e o) +withReducibleEmissions sid msg cont = do + withDecls <- buildScoped cont + reduceWithDecls withDecls >>= \case + Just t -> return t + _ -> throw sid msg +{-# INLINE withReducibleEmissions #-} -withBlockDecls - :: EmitsBoth o - => UBlock i -> (forall i'. UExpr i' -> InfererM i' o a) -> InfererM i o a -withBlockDecls (WithSrcE src (UBlock declsTop result)) contTop = - addSrcContext src $ go declsTop $ contTop result where - go :: EmitsBoth o => Nest UDecl i i' -> InfererM i' o a -> InfererM i o a - go decls cont = case decls of - Empty -> cont - Nest d ds -> withUDecl d $ go ds $ cont +-- === actual inference pass === -withUDecl - :: EmitsBoth o - => UDecl i i' - -> InfererM i' o a - -> InfererM i o a -withUDecl (WithSrcB src d) cont = addSrcContext src case d of - UPass -> cont - UExprDecl e -> inferSigma noHint e >> cont - ULet letAnn p ann rhs -> do - val <- checkMaybeAnnExpr (getNameHint p) ann rhs - var <- emitDecl (getNameHint p) letAnn $ Atom val - bindLetPat p var cont +data RequiredTy (n::S) = + Check (CType n) + | Infer + deriving Show --- "rho" means the required type here should not be (at the top level) an implicit pi type or --- an implicit dependent pair type. We don't want to unify those directly. --- The name hint names the object being computed -checkOrInferRho - :: forall i o. EmitsBoth o - => NameHint -> UExpr i -> RequiredTy CType o -> InfererM i o (CAtom o) -checkOrInferRho hint uExprWithSrc@(WithSrcE pos expr) reqTy = do - addSrcContext pos $ confuseGHC >>= \_ -> case expr of - UVar _ -> inferAndInstantiate - ULit l -> matchRequirement $ Con $ Lit l - ULam lamExpr -> do +data PartialPiType (n::S) where + PartialPiType + :: AppExplicitness -> [Explicitness] + -> Nest CBinder n l + -> EffectRow CoreIR l + -> RequiredTy l + -> PartialPiType n + +data PartialType (n::S) = + PartialType (PartialPiType n) + | FullType (CType n) + +checkOrInfer :: Emits o => RequiredTy o -> UExpr i -> InfererM i o (CAtom o) +checkOrInfer reqTy expr = case reqTy of + Infer -> bottomUp expr + Check t -> topDown t expr + +topDown :: forall i o. Emits o => CType o -> UExpr i -> InfererM i o (CAtom o) +topDown ty uexpr = topDownPartial (typeAsPartialType ty) uexpr + +topDownPartial :: Emits o => PartialType o -> UExpr i -> InfererM i o (CAtom o) +topDownPartial partialTy exprWithSrc@(WithSrcE sid expr) = + case partialTy of + PartialType partialPiTy -> case expr of + ULam lam -> toAtom <$> Lam <$> checkULamPartial partialPiTy sid lam + _ -> toAtom <$> Lam <$> etaExpandPartialPi partialPiTy \resultTy explicitArgs -> do + expr' <- bottomUpExplicit exprWithSrc + dropSubst $ checkOrInferApp sid sid expr' explicitArgs [] resultTy + FullType ty -> topDownExplicit ty exprWithSrc + +-- Creates a lambda for all args and returns (via CPA) the explicit args +etaExpandPartialPi + :: PartialPiType o + -> (forall o'. (Emits o', DExt o o') => RequiredTy o' -> [CAtom o'] -> InfererM i o' (CAtom o')) + -> InfererM i o (CoreLamExpr o) +etaExpandPartialPi (PartialPiType appExpl expls bs effs reqTy) cont = do + withFreshBindersInf expls (Abs bs (PairE effs reqTy)) \bs' (PairE effs' reqTy') -> do + let args = zip expls (toAtom <$> bindersVars bs') + explicits <- return $ catMaybes $ args <&> \case + (Explicit, arg) -> Just arg + _ -> Nothing + withAllowedEffects effs' do + body <- buildBlock $ cont (sink reqTy') (sink <$> explicits) + let piTy = CorePiType appExpl expls bs' (EffTy effs' $ getType body) + return $ CoreLamExpr piTy $ LamExpr bs' body + +-- Doesn't introduce implicit pi binders or dependent pairs +topDownExplicit :: forall i o. Emits o => CType o -> UExpr i -> InfererM i o (CAtom o) +topDownExplicit reqTy exprWithSrc@(WithSrcE sid expr) = case expr of + ULam lamExpr -> case reqTy of + TyCon (Pi piTy) -> toAtom <$> Lam <$> checkULam sid lamExpr piTy + _ -> throw sid $ UnexpectedTerm "lambda" (pprint reqTy) + UFor dir uFor@(UForExpr b _) -> case reqTy of + TyCon (TabPi tabPiTy) -> do + lam@(UnaryLamExpr b' _) <- checkUForExpr sid uFor tabPiTy + ixTy <- asIxType (getSrcId b) $ binderType b' + emitHof $ For dir ixTy lam + _ -> throw sid $ UnexpectedTerm "`for` expression" (pprint reqTy) + UApp f posArgs namedArgs -> do + f' <- bottomUpExplicit f + checkOrInferApp sid (getSrcId f) f' posArgs namedArgs (Check reqTy) + UDepPair lhs rhs -> case reqTy of + TyCon (DepPairTy ty@(DepPairType _ (_ :> lhsTy) _)) -> do + lhs' <- checkSigmaDependent lhs (FullType lhsTy) + rhsTy <- instantiate ty [lhs'] + rhs' <- topDown rhsTy rhs + return $ toAtom $ DepPair lhs' rhs' ty + _ -> throw sid $ UnexpectedTerm "dependent pair" (pprint reqTy) + UCase scrut alts -> do + scrut' <- bottomUp scrut + let scrutTy = getType scrut' + alts' <- mapM (checkCaseAlt (Check reqTy) scrutTy) alts + buildSortedCase scrut' alts' reqTy + UDo block -> withBlockDecls block \result -> topDownExplicit (sink reqTy) result + UTabCon xs -> do case reqTy of - Check (Pi piTy) -> Lam <$> checkULam lamExpr piTy - Check _ -> Lam <$> inferULam lamExpr >>= matchRequirement - Infer -> Lam <$> inferULam lamExpr - UFor dir uFor -> do - lam@(UnaryLamExpr b' _) <- case reqTy of - Check (TabPi tabPiTy) -> do checkUForExpr uFor tabPiTy - Check _ -> inferUForExpr uFor - Infer -> inferUForExpr uFor - ixTy <- asIxType $ binderType b' - result <- emitHof $ For dir ixTy lam - matchRequirement result + TyCon (TabPi tabPiTy) -> checkTabCon tabPiTy sid xs + _ -> throw sid $ UnexpectedTerm "table constructor" (pprint reqTy) + UNatLit x -> fromNatLit sid x reqTy + UIntLit x -> fromIntLit sid x reqTy + UPrim UTuple xs -> case reqTy of + TyKind -> toAtom . ProdType <$> mapM checkUType xs + TyCon (ProdType reqTys) -> do + when (length reqTys /= length xs) $ throw sid $ TupleLengthMismatch (length reqTys) (length xs) + toAtom <$> ProdCon <$> forM (zip reqTys xs) \(reqTy', x) -> topDown reqTy' x + _ -> throw sid $ UnexpectedTerm "tuple" (pprint reqTy) + UFieldAccess _ _ -> infer + UVar _ -> infer + UTypeAnn _ _ -> infer + UTabApp _ _ -> infer + UFloatLit _ -> infer + UPrim _ _ -> infer + ULit _ -> infer + UPi _ -> infer + UTabPi _ -> infer + UDepPairTy _ -> infer + UHole -> throw sid InferHoleErr + where + infer :: InfererM i o (CAtom o) + infer = do + sigmaAtom <- maybeInterpretPunsAsTyCons (Check reqTy) =<< bottomUpExplicit exprWithSrc + instantiateSigma sid (Check reqTy) sigmaAtom + +bottomUp :: Emits o => UExpr i -> InfererM i o (CAtom o) +bottomUp expr = bottomUpExplicit expr >>= instantiateSigma (getSrcId expr) Infer + +-- Doesn't instantiate implicit args +bottomUpExplicit :: Emits o => UExpr i -> InfererM i o (SigmaAtom o) +bottomUpExplicit (WithSrcE sid expr) = case expr of + UVar ~(InternalName _ sn v) -> do + v' <- renameM v + ty <- getUVarType v' + emitTypeInfo sid $ pprint sn ++ " : " ++ pprint ty + return $ SigmaUVar sn ty v' + ULit l -> return $ SigmaAtom Nothing $ Con $ Lit l + UFieldAccess x (WithSrc _ field) -> do + x' <- bottomUp x + ty <- return $ getType x' + fields <- getFieldDefs sid ty + case M.lookup field fields of + Just def -> case def of + FieldProj i -> SigmaAtom Nothing <$> projectField i x' + FieldDotMethod method (TyConParams _ params) -> do + method' <- toAtomVar method + resultTy <- partialAppType (getType method') (params ++ [x']) + return $ SigmaPartialApp resultTy (toAtom method') (params ++ [x']) + Nothing -> throw sid $ CantFindField (pprint field) (pprint ty) (map pprint $ M.keys fields) + ULam lamExpr -> SigmaAtom Nothing <$> toAtom <$> inferULam lamExpr + UFor dir uFor@(UForExpr b _) -> do + lam@(UnaryLamExpr b' _) <- inferUForExpr uFor + ixTy <- asIxType (getSrcId b) $ binderType b' + liftM (SigmaAtom Nothing) $ emitHof $ For dir ixTy lam UApp f posArgs namedArgs -> do - f' <- inferWithoutInstantiation f >>= zonk - checkOrInferApp f' posArgs namedArgs reqTy + f' <- bottomUpExplicit f + SigmaAtom Nothing <$> checkOrInferApp sid (getSrcId f) f' posArgs namedArgs Infer UTabApp tab args -> do - tab' <- inferRho noHint tab >>= zonk - inferTabApp (srcPos tab) tab' args >>= matchRequirement + tab' <- bottomUp tab + SigmaAtom Nothing <$> inferTabApp (getSrcId tab) tab' args UPi (UPiExpr bs appExpl effs ty) -> do -- TODO: check explicitness constraints - ab <- withUBinders bs \_ -> EffTy <$> checkUEffRow effs <*> checkUType ty - Abs bs' effTy' <- return ab - let (expls, bs'') = unzipAttrs bs' - matchRequirement $ Type $ Pi $ CorePiType appExpl expls bs'' effTy' - UTabPi (UTabPiExpr (UAnnBinder b ann cs) ty) -> do - unless (null cs) $ throw TypeErr "`=>` shouldn't have constraints" - ann' <- asIxType =<< checkAnn (getSourceName b) ann - piTy <- case b of - UIgnore -> - buildTabPiInf noHint ann' \_ -> checkUType ty - _ -> buildTabPiInf (getNameHint b) ann' \v -> extendRenamer (b@>atomVarName v) do - let msg = "Can't reduce type expression: " ++ docAsStr (pretty ty) - Type rhs <- withReducibleEmissions msg $ Type <$> checkUType ty - return rhs - matchRequirement $ Type $ TabPi piTy - UDepPairTy (UDepPairType expl (UAnnBinder b ann cs) rhs) -> do - unless (null cs) $ throw TypeErr "Dependent pair binders shouldn't have constraints" - ann' <- checkAnn (getSourceName b) ann - matchRequirement =<< liftM (Type . DepPairTy) do - buildDepPairTyInf (getNameHint b) expl ann' \v -> extendRenamer (b@>atomVarName v) do - let msg = "Can't reduce type expression: " ++ docAsStr (pretty rhs) - withReducibleEmissions msg $ checkUType rhs - UDepPair lhs rhs -> do - case reqTy of - Check (DepPairTy ty@(DepPairType _ (_ :> lhsTy) _)) -> do - lhs' <- checkSigmaDependent noHint lhs lhsTy - rhsTy <- instantiate ty [lhs'] - rhs' <- checkSigma noHint rhs rhsTy - return $ DepPair lhs' rhs' ty - _ -> throw TypeErr $ "Can't infer the type of a dependent pair; please annotate it" - UCase scrut alts -> do - scrut' <- inferRho noHint scrut - scrutTy <- return $ getType scrut' - reqTy' <- case reqTy of - Infer -> freshType - Check req -> return req - alts' <- mapM (checkCaseAlt reqTy' scrutTy) alts - scrut'' <- zonk scrut' - buildSortedCase scrut'' alts' reqTy' - UDo block -> withBlockDecls block \result -> checkOrInferRho hint result reqTy - UTabCon xs -> inferTabCon hint xs reqTy >>= matchRequirement - UHole -> case reqTy of - Infer -> throw MiscErr "Can't infer type of hole" - Check ty -> freshAtom ty + withUBinders bs \(ZipB expls bs') -> do + effTy' <- EffTy <$> checkUEffRow effs <*> checkUType ty + return $ SigmaAtom Nothing $ toAtom $ + Pi $ CorePiType appExpl expls bs' effTy' + UTabPi (UTabPiExpr b ty) -> do + Abs b' ty' <- withUBinder b \(WithAttrB _ b') -> + liftM (Abs b') $ checkUType ty + d <- getIxDict (getSrcId b) $ binderType b' + let piTy = TabPiType d b' ty' + return $ SigmaAtom Nothing $ toAtom $ TabPi piTy + UDepPairTy (UDepPairType expl b rhs) -> do + withUBinder b \(WithAttrB _ b') -> do + rhs' <- checkUType rhs + return $ SigmaAtom Nothing $ toAtom $ DepPairTy $ DepPairType expl b' rhs' + UDepPair _ _ -> throw sid InferDepPairErr + UCase scrut (alt:alts) -> do + scrut' <- bottomUp scrut + let scrutTy = getType scrut' + alt'@(IndexedAlt _ altAbs) <- checkCaseAlt Infer scrutTy alt + Abs b ty <- liftEnvReaderM $ refreshAbs altAbs \b body -> do + return $ Abs b (getType body) + resultTy <- liftHoistExcept sid $ hoist b ty + alts' <- mapM (checkCaseAlt (Check resultTy) scrutTy) alts + SigmaAtom Nothing <$> buildSortedCase scrut' (alt':alts') resultTy + UCase _ [] -> throw sid InferEmptyCaseEff + UDo block -> withBlockDecls block \result -> bottomUpExplicit result + UTabCon xs -> liftM (SigmaAtom Nothing) $ inferTabCon sid xs UTypeAnn val ty -> do - ty' <- zonk =<< checkUType ty - val' <- checkSigma hint val ty' - matchRequirement val' - UPrim UTuple xs -> case reqTy of - Check TyKind -> Type . ProdTy <$> mapM checkUType xs - _ -> do - xs' <- mapM (inferRho noHint) xs - matchRequirement $ ProdVal xs' + ty' <- checkUType ty + liftM (SigmaAtom Nothing) $ topDown ty' val + UPrim UTuple xs -> do + xs' <- forM xs \x -> bottomUp x + return $ SigmaAtom Nothing $ Con $ ProdCon xs' UPrim UMonoLiteral [WithSrcE _ l] -> case l of - UIntLit x -> matchRequirement $ Con $ Lit $ Int32Lit $ fromIntegral x - UNatLit x -> matchRequirement $ Con $ Lit $ Word32Lit $ fromIntegral x - _ -> throw MiscErr "argument to %monoLit must be a literal" + UIntLit x -> return $ SigmaAtom Nothing $ Con $ Lit $ Int32Lit $ fromIntegral x + UNatLit x -> return $ SigmaAtom Nothing $ Con $ Lit $ Word32Lit $ fromIntegral x + _ -> throwInternal "argument to %monoLit must be a literal" UPrim UExplicitApply (f:xs) -> do - f' <- inferWithoutInstantiation f - xs' <- mapM (inferRho noHint) xs - applySigmaAtom f' xs' >>= matchRequirement + f' <- bottomUpExplicit f + xs' <- mapM bottomUp xs + SigmaAtom Nothing <$> applySigmaAtom sid f' xs' UPrim UProjNewtype [x] -> do - x' <- inferRho hint x >>= emitHinted hint . Atom - unwrapNewtype $ Var x' + x' <- bottomUp x >>= unwrapNewtype + return $ SigmaAtom Nothing x' UPrim prim xs -> do xs' <- forM xs \x -> do inferPrimArg x >>= \case - Var v -> lookupAtomName (atomVarName v) >>= \case + Stuck _ (Var v) -> lookupAtomName (atomVarName v) >>= \case LetBound (DeclBinding _ (Atom e)) -> return e - _ -> return $ Var v + _ -> return $ toAtom v x' -> return x' - matchRequirement =<< matchPrimApp prim xs' - UFieldAccess _ _ -> inferAndInstantiate - UNatLit x -> do - let defaultVal = Con $ Lit $ Word32Lit $ fromIntegral x - let litVal = Con $ Lit $ Word64Lit $ fromIntegral x - matchRequirement =<< applyFromLiteralMethod "from_unsigned_integer" defaultVal NatDefault litVal - UIntLit x -> do - let defaultVal = Con $ Lit $ Int32Lit $ fromIntegral x - let litVal = Con $ Lit $ Int64Lit $ fromIntegral x - matchRequirement =<< applyFromLiteralMethod "from_integer" defaultVal IntDefault litVal - UFloatLit x -> matchRequirement $ Con $ Lit $ Float32Lit $ realToFrac x - -- TODO: Make sure that this conversion is not lossy! + liftM (SigmaAtom Nothing) $ matchPrimApp prim xs' + UNatLit l -> liftM (SigmaAtom Nothing) $ fromNatLit sid l NatTy + UIntLit l -> liftM (SigmaAtom Nothing) $ fromIntLit sid l (BaseTy $ Scalar Int32Type) + UFloatLit x -> return $ SigmaAtom Nothing $ Con $ Lit $ Float32Lit $ realToFrac x + UHole -> throw sid InferHoleErr + +expectEq :: (PrettyE e, AlphaEqE e) => SrcId -> e o -> e o -> InfererM i o () +expectEq sid reqTy actualTy = alphaEq reqTy actualTy >>= \case + True -> return () + False -> throw sid $ TypeMismatch (pprint reqTy) (pprint actualTy) +{-# INLINE expectEq #-} + +fromIntLit :: Emits o => SrcId -> Int -> CType o -> InfererM i o (CAtom o) +fromIntLit sid x ty = do + let litVal = Con $ Lit $ Int64Lit $ fromIntegral x + applyFromLiteralMethod sid ty "from_integer" litVal + +fromNatLit :: Emits o => SrcId -> Word64 -> CType o -> InfererM i o (CAtom o) +fromNatLit sid x ty = do + let litVal = Con $ Lit $ Word64Lit $ fromIntegral x + applyFromLiteralMethod sid ty "from_unsigned_integer" litVal + +instantiateSigma :: Emits o => SrcId -> RequiredTy o -> SigmaAtom o -> InfererM i o (CAtom o) +instantiateSigma sid reqTy sigmaAtom = case sigmaAtom of + SigmaUVar _ _ _ -> case getType sigmaAtom of + TyCon (Pi (CorePiType ImplicitApp expls bs (EffTy _ resultTy))) -> do + bsConstrained <- buildConstraints (Abs bs resultTy) \_ resultTy' -> do + case reqTy of + Infer -> return [] + Check reqTy' -> return [TypeConstraint sid (sink reqTy') resultTy'] + args <- inferMixedArgs @UExpr sid fDesc expls bsConstrained ([], []) + applySigmaAtom sid sigmaAtom args + _ -> fallback + _ -> fallback where - matchRequirement :: CAtom o -> InfererM i o (CAtom o) - matchRequirement x = return x <* - case reqTy of - Infer -> return () - Check req -> do - ty <- return $ getType x - constrainTypesEq req ty - {-# INLINE matchRequirement #-} - - inferAndInstantiate :: InfererM i o (CAtom o) - inferAndInstantiate = do - sigmaAtom <- maybeInterpretPunsAsTyCons reqTy =<< inferWithoutInstantiation uExprWithSrc - instantiateSigma sigmaAtom >>= matchRequirement - {-# INLINE inferAndInstantiate #-} - -applyFromLiteralMethod :: EmitsBoth n => SourceName -> CAtom n -> DefaultType -> CAtom n -> InfererM i n (CAtom n) -applyFromLiteralMethod methodName defaultVal defaultTy litVal = do + fallback = forceSigmaAtom sid sigmaAtom >>= matchReq sid reqTy + fDesc = getSourceName sigmaAtom + +matchReq :: Ext o o' => SrcId -> RequiredTy o -> CAtom o' -> InfererM i o' (CAtom o') +matchReq sid (Check reqTy) x = do + reqTy' <- sinkM reqTy + return x <* expectEq sid reqTy' (getType x) +matchReq _ Infer x = return x +{-# INLINE matchReq #-} + +forceSigmaAtom :: Emits o => SrcId -> SigmaAtom o -> InfererM i o (CAtom o) +forceSigmaAtom sid sigmaAtom = case sigmaAtom of + SigmaAtom _ x -> return x + SigmaUVar _ _ v -> case v of + UAtomVar v' -> inlineTypeAliases v' + _ -> applySigmaAtom sid sigmaAtom [] + SigmaPartialApp _ _ _ -> error "not implemented" -- better error message? + +withBlockDecls + :: (Emits o, Zonkable e) + => UBlock i + -> (forall i' o'. (Emits o', DExt o o') => UExpr i' -> InfererM i' o' (e o')) + -> InfererM i o (e o) +withBlockDecls (WithSrcE _ (UBlock declsTop result)) contTop = + go declsTop $ contTop result where + go :: (Emits o, Zonkable e) + => Nest UDecl i i' + -> (forall o'. (Emits o', DExt o o') => InfererM i' o' (e o')) + -> InfererM i o (e o) + go decls cont = case decls of + Empty -> withDistinct cont + Nest d ds -> withUDecl d $ go ds $ cont + +withUDecl + :: (Emits o, Zonkable e) + => UDecl i i' + -> (forall o'. (Emits o', DExt o o') => InfererM i' o' (e o')) + -> InfererM i o (e o) +withUDecl (WithSrcB _ d) cont = case d of + UPass -> withDistinct cont + UExprDecl e -> withDistinct $ bottomUp e >> cont + ULet letAnn p ann rhs -> do + val <- checkMaybeAnnExpr ann rhs + let letAnn' = considerInlineAnn letAnn (getType val) + var <- emitDecl (getNameHint p) letAnn' $ Atom val + bindLetPat p var cont + +considerInlineAnn :: LetAnn -> CType n -> LetAnn +considerInlineAnn PlainLet TyKind = InlineLet +considerInlineAnn PlainLet (TyCon (Pi (CorePiType _ _ _ (EffTy Pure TyKind)))) = InlineLet +considerInlineAnn ann _ = ann + +applyFromLiteralMethod + :: Emits n => SrcId -> CType n -> SourceName -> CAtom n -> InfererM i n (CAtom n) +applyFromLiteralMethod sid resultTy methodName litVal = lookupSourceMap methodName >>= \case - Nothing -> return defaultVal + Nothing -> error $ "prelude function not found: " ++ pprint methodName Just ~(UMethodVar methodName') -> do MethodBinding className _ <- lookupEnv methodName' - resultTyVar <- freshInferenceName MiscInfVar TyKind - dictTy <- DictTy <$> dictType className [Var resultTyVar] - addDefault (atomVarName resultTyVar) defaultTy - emitExpr =<< mkApplyMethod (DictHole (AlwaysEqual emptySrcPosCtx) dictTy Full) 0 [litVal] + dictTy <- toType <$> dictType className [toAtom resultTy] + Just d <- toMaybeDict <$> trySynthTerm sid dictTy Full + emit =<< mkApplyMethod d 0 [litVal] -- atom that requires instantiation to become a rho type data SigmaAtom n = @@ -1107,242 +685,115 @@ instance HasSourceName (SigmaAtom n) where SigmaUVar sn _ _ -> sn SigmaPartialApp _ _ _ -> "" -instance SinkableE SigmaAtom where - sinkingProofE = error "it's fine, trust me" - -instance SubstE AtomSubstVal SigmaAtom where - substE env (SigmaAtom sn x) = SigmaAtom sn $ substE env x - substE env (SigmaUVar sn ty uvar) = case uvar of - UAtomVar v -> substE env $ SigmaAtom (Just sn) $ Var (AtomVar v ty) - UTyConVar v -> SigmaUVar sn ty' $ UTyConVar $ substE env v - UDataConVar v -> SigmaUVar sn ty' $ UDataConVar $ substE env v - UPunVar v -> SigmaUVar sn ty' $ UPunVar $ substE env v - UClassVar v -> SigmaUVar sn ty' $ UClassVar $ substE env v - UMethodVar v -> SigmaUVar sn ty' $ UMethodVar $ substE env v - UEffectVar _ -> error "not implemented" - UEffectOpVar _ -> error "not implemented" - where ty' = substE env ty - substE env (SigmaPartialApp ty f xs) = - SigmaPartialApp (substE env ty) (substE env f) (map (substE env) xs) - --- XXX: this must handle the complement of the cases that `checkOrInferRho` --- handles directly or else we'll just keep bouncing between the two. -inferWithoutInstantiation - :: forall i o. EmitsBoth o - => UExpr i -> InfererM i o (SigmaAtom o) -inferWithoutInstantiation (WithSrcE pos expr) = - addSrcContext pos $ confuseGHC >>= \_ -> case expr of - UVar ~(InternalName _ sn v) -> do - v' <- renameM v - ty <- getUVarType v' - return $ SigmaUVar sn ty v' - UFieldAccess x (WithSrc pos' field) -> addSrcContext pos' do - x' <- inferRho noHint x >>= zonk - ty <- return $ getType x' - fields <- getFieldDefs ty - case M.lookup field fields of - Just def -> case def of - FieldProj i -> SigmaAtom Nothing <$> projectField i x' - FieldDotMethod method (TyConParams _ params) -> do - method' <- toAtomVar method - resultTy <- partialAppType (getType method') (params ++ [x']) - return $ SigmaPartialApp resultTy (Var method') (params ++ [x']) - Nothing -> throw TypeErr $ - "Can't resolve field " ++ pprint field ++ " of type " ++ pprint ty ++ - "\nKnown fields are: " ++ pprint (M.keys fields) - _ -> SigmaAtom Nothing <$> inferRho noHint (WithSrcE pos expr) - data FieldDef (n::S) = FieldProj Int | FieldDotMethod (CAtomName n) (TyConParams n) deriving (Show, Generic) -getFieldDefs :: CType n -> InfererM i n (M.Map FieldName' (FieldDef n)) -getFieldDefs ty = case ty of - NewtypeTyCon (UserADTType _ tyName params) -> do - TyConBinding ~(Just tyDef) (DotMethods dotMethods) <- lookupEnv tyName - instantiateTyConDef tyDef params >>= \case - StructFields fields -> do - let projFields = enumerate fields <&> \(i, (field, _)) -> - [(FieldName field, FieldProj i), (FieldNum i, FieldProj i)] - let methodFields = M.toList dotMethods <&> \(field, f) -> - (FieldName field, FieldDotMethod f params) - return $ M.fromList $ concat projFields ++ methodFields - ADTCons _ -> noFields "" - RefTy _ valTy -> case valTy of - RefTy _ _ -> noFields "" - _ -> do - valFields <- getFieldDefs valTy - return $ M.filter isProj valFields - where isProj = \case - FieldProj _ -> True - _ -> False - ProdTy ts -> return $ M.fromList $ enumerate ts <&> \(i, _) -> (FieldNum i, FieldProj i) - TabPi _ -> noFields "\nArray indexing uses [] now." - _ -> noFields "" - where - noFields s = throw TypeErr $ "Can't get fields for type " ++ pprint ty ++ s - -instantiateSigma :: forall i o. EmitsBoth o => SigmaAtom o -> InfererM i o (CAtom o) -instantiateSigma sigmaAtom = case getType sigmaAtom of - Pi piTy@(CorePiType ExplicitApp _ _ _) -> do - Lam <$> etaExpandExplicits fDesc piTy \args -> - applySigmaAtom (sink sigmaAtom) args - Pi (CorePiType ImplicitApp expls bs (EffTy _ resultTy)) -> do - args <- inferMixedArgs @UExpr fDesc expls (Abs bs resultTy) [] [] - applySigmaAtom sigmaAtom args - DepPairTy (DepPairType ImplicitDepPair _ _) -> - -- TODO: we should probably call instantiateSigma again here in case - -- we have nested dependent pairs. Also, it looks like this doesn't - -- get called after function application. We probably want to fix that. - fallback >>= getSnd - _ -> fallback - where - fallback = case sigmaAtom of - SigmaAtom _ x -> return x - SigmaUVar _ _ v -> case v of - UAtomVar v' -> do - v'' <- toAtomVar v' - return $ Var v'' - _ -> applySigmaAtom sigmaAtom [] - SigmaPartialApp _ _ _ -> error "shouldn't hit this case because we should have a pi type here" - fDesc :: SourceName - fDesc = getSourceName sigmaAtom +getFieldDefs :: SrcId -> CType n -> InfererM i n (M.Map FieldName' (FieldDef n)) +getFieldDefs sid ty = case ty of + StuckTy _ _ -> noFields + TyCon con -> case con of + NewtypeTyCon (UserADTType _ tyName params) -> do + TyConBinding ~(Just tyDef) (DotMethods dotMethods) <- lookupEnv tyName + instantiateTyConDef tyDef params >>= \case + StructFields fields -> do + let projFields = enumerate fields <&> \(i, (field, _)) -> + [(FieldName field, FieldProj i), (FieldNum i, FieldProj i)] + let methodFields = M.toList dotMethods <&> \(field, f) -> + (FieldName field, FieldDotMethod f params) + return $ M.fromList $ concat projFields ++ methodFields + ADTCons _ -> noFields + RefType _ valTy -> case valTy of + RefTy _ _ -> noFields + _ -> do + valFields <- getFieldDefs sid valTy + return $ M.filter isProj valFields + where isProj = \case + FieldProj _ -> True + _ -> False + ProdType ts -> return $ M.fromList $ enumerate ts <&> \(i, _) -> (FieldNum i, FieldProj i) + TabPi _ -> noFields + _ -> noFields + where noFields = throw sid $ NoFields $ pprint ty projectField :: Emits o => Int -> CAtom o -> InfererM i o (CAtom o) projectField i x = case getType x of - ProdTy _ -> projectTuple i x - NewtypeTyCon _ -> projectStruct i x - RefTy _ valTy -> case valTy of - ProdTy _ -> getProjRef (ProjectProduct i) x - NewtypeTyCon _ -> projectStructRef i x + StuckTy _ _ -> bad + TyCon con -> case con of + ProdType _ -> proj i x + NewtypeTyCon _ -> projectStruct i x + RefType _ valTy -> case valTy of + TyCon (ProdType _) -> getProjRef (ProjectProduct i) x + TyCon (NewtypeTyCon _) -> projectStructRef i x + _ -> bad _ -> bad - _ -> bad where bad = error $ "bad projection: " ++ pprint (i, x) --- creates a lambda term with just the explicit binders, but provides --- args corresponding to all the binders (explicit and implicit) -etaExpandExplicits - :: EmitsInf o => SourceName -> CorePiType o - -> (forall o'. (EmitsBoth o', DExt o o') => [CAtom o'] -> InfererM i o' (CAtom o')) - -> InfererM i o (CoreLamExpr o) -etaExpandExplicits fSourceName (CorePiType _ explsTop bsTop (EffTy effs _)) contTop = do - Abs bs body <- go explsTop bsTop \xs -> do - effs' <- applySubst (bsTop@@>(SubstVal<$>xs)) effs - withAllowedEffects effs' do - body <- buildBlockInf $ contTop $ sinkList xs - return $ PairE effs' body - let (expls, bs') = unzipAttrs bs - coreLamExpr ExplicitApp expls $ Abs bs' body - where - go :: (EmitsInf o, SinkableE e, RenameE e, SubstE AtomSubstVal e, HoistableE e ) - => [Explicitness] -> Nest CBinder o any - -> (forall o'. (EmitsInf o', DExt o o') => [CAtom o'] -> InfererM i o' (e o')) - -> InfererM i o (Abs (Nest (WithExpl CBinder)) e o) - go [] Empty cont = getDistinct >>= \Distinct -> Abs Empty <$> cont [] - go (expl:expls) (Nest (b:>ty) rest) cont = case expl of - Explicit -> do - prependAbs <$> buildAbsInfWithExpl (getNameHint b) expl ty \v -> do - Abs rest' UnitE <- applyRename (b@>atomVarName v) $ Abs rest UnitE - go expls rest' \args -> cont (sink (Var v) : args) - Inferred argSourceName infMech -> do - arg <- getImplicitArg (fSourceName, fromMaybe "_" argSourceName) infMech ty - Abs rest' UnitE <- applySubst (b@>SubstVal arg) $ Abs rest UnitE - go expls rest' \args -> cont (sink arg : args) - go _ _ _ = error "zip error" - -buildLamInf - :: EmitsInf o => CorePiType o - -> (forall o' . (EmitsBoth o', DExt o o') - => [(Explicitness, CAtom o')] -> CType o' -> InfererM i o' (CAtom o')) - -> InfererM i o (CoreLamExpr o) -buildLamInf (CorePiType appExpl explsTop bsTop effTy) contTop = do - ab <- go explsTop bsTop \xs -> do - let (expls, xs') = unzip xs - EffTy effs' resultTy' <- applySubst (bsTop@@>(SubstVal<$>xs')) effTy - withAllowedEffects effs' do - body <- buildBlockInf $ contTop (zip expls $ sinkList xs') (sink resultTy') - return $ PairE effs' body - coreLamExpr appExpl explsTop ab - where - go :: (EmitsInf o, HoistableE e, SinkableE e, SubstE AtomSubstVal e, RenameE e) - => [Explicitness] -> Nest CBinder o any - -> (forall o'. (EmitsInf o', DExt o o') => [(Explicitness, CAtom o')] -> InfererM i o' (e o')) - -> InfererM i o (Abs (Nest CBinder) e o) - go [] Empty cont = getDistinct >>= \Distinct -> Abs Empty <$> cont [] - go (expl:expls) (Nest b rest) cont = do - prependAbs <$> buildAbsInf (getNameHint b) expl (binderType b) \v -> do - Abs rest' UnitE <- applyRename (b@>atomVarName v) $ Abs rest UnitE - go expls rest' \args -> cont $ (expl, sink $ Var v) : args - go _ _ _ = error "zip error" - -class ExplicitArg (e::E) where - checkExplicitArg :: EmitsBoth o => IsDependent -> e i -> CType o -> InfererM i o (CAtom o) - inferExplicitArg :: EmitsBoth o => e i -> InfererM i o (CAtom o) +class PrettyE e => ExplicitArg (e::E) where + checkExplicitNonDependentArg :: Emits o => e i -> PartialType o -> InfererM i o (CAtom o) + checkExplicitDependentArg :: e i -> PartialType o -> InfererM i o (CAtom o) + inferExplicitArg :: Emits o => e i -> InfererM i o (CAtom o) + isHole :: e n -> Bool + explicitArgSrcId :: e n -> SrcId instance ExplicitArg UExpr where - checkExplicitArg isDependent arg argTy = - if isDependent - then checkSigmaDependent noHint arg argTy - else checkSigma noHint arg argTy - - inferExplicitArg arg = inferRho noHint arg + checkExplicitDependentArg arg argTy = checkSigmaDependent arg argTy + checkExplicitNonDependentArg arg argTy = topDownPartial argTy arg + inferExplicitArg arg = bottomUp arg + isHole = \case + WithSrcE _ UHole -> True + _ -> False + explicitArgSrcId = getSrcId instance ExplicitArg CAtom where - checkExplicitArg _ arg argTy = do - arg' <- renameM arg - constrainTypesEq argTy $ getType arg' - return arg' + checkExplicitDependentArg = checkCAtom + checkExplicitNonDependentArg = checkCAtom inferExplicitArg arg = renameM arg + isHole _ = False + explicitArgSrcId _ = rootSrcId + +checkCAtom :: CAtom i -> PartialType o -> InfererM i o (CAtom o) +checkCAtom arg argTy = do + arg' <- renameM arg + case argTy of + FullType argTy' -> expectEq rootSrcId argTy' (getType arg') + PartialType _ -> return () -- TODO? + return arg' checkOrInferApp - :: forall i o arg - . (EmitsBoth o, ExplicitArg arg) - => SigmaAtom o -> [arg i] -> [(SourceName, arg i)] - -> RequiredTy CType o - -> InfererM i o (CAtom o) -checkOrInferApp f' posArgs namedArgs reqTy = do + :: forall i o arg . (Emits o, ExplicitArg arg) + => SrcId -> SrcId -> SigmaAtom o -> [arg i] -> [(SourceName, arg i)] + -> RequiredTy o -> InfererM i o (CAtom o) +checkOrInferApp appSrcId funSrcId f' posArgs namedArgs reqTy = do f <- maybeInterpretPunsAsTyCons reqTy f' case getType f of - Pi (CorePiType appExpl expls bs effTy) -> case appExpl of + TyCon (Pi piTy@(CorePiType appExpl expls _ _)) -> case appExpl of ExplicitApp -> do - checkArity expls posArgs - args' <- inferMixedArgs fDesc expls (Abs bs effTy) posArgs namedArgs - applySigmaAtom f args' >>= matchRequirement - ImplicitApp -> do - -- TODO: should this already have been done by the time we get `f`? - implicitArgs <- inferMixedArgs @UExpr fDesc expls (Abs bs effTy) [] [] - f'' <- SigmaAtom (Just fDesc) <$> applySigmaAtom f implicitArgs - checkOrInferApp f'' posArgs namedArgs Infer >>= matchRequirement - -- TODO: special-case error for when `fTy` can't possibly be a function - fTy -> do - when (not $ null namedArgs) do - throw TypeErr "Can't infer function types with named arguments" - args' <- mapM inferExplicitArg posArgs - argTys <- return $ map getType args' - resultTy <- getResultTy - let expected = nonDepPiType argTys Pure resultTy - constrainTypesEq (Pi expected) fTy - f'' <- zonk f - applySigmaAtom f'' args' + checkExplicitArity appSrcId expls posArgs + bsConstrained <- buildAppConstraints appSrcId reqTy piTy + args <- inferMixedArgs appSrcId fDesc expls bsConstrained (posArgs, namedArgs) + applySigmaAtom appSrcId f args + ImplicitApp -> error "should already have handled this case" + ty -> throw funSrcId $ EliminationErr "function type" (pprint ty) where fDesc :: SourceName fDesc = getSourceName f' - getResultTy :: InfererM i o (CType o) - getResultTy = case reqTy of - Infer -> freshType - Check req -> return req - - matchRequirement :: CAtom o -> InfererM i o (CAtom o) - matchRequirement x = return x <* - case reqTy of - Infer -> return () - Check req -> do - ty <- return $ getType x - constrainTypesEq req ty - -maybeInterpretPunsAsTyCons :: RequiredTy CType n -> SigmaAtom n -> InfererM i n (SigmaAtom n) +buildAppConstraints :: SrcId -> RequiredTy n -> CorePiType n -> InfererM i n (ConstrainedBinders n) +buildAppConstraints appSrcId reqTy (CorePiType _ _ bs effTy) = do + effsAllowed <- infEffects <$> getInfState + buildConstraints (Abs bs effTy) \_ (EffTy effs resultTy) -> do + resultTyConstraints <- return case reqTy of + Infer -> [] + Check reqTy' -> [TypeConstraint appSrcId (sink reqTy') resultTy] + EffectRow _ t <- return effs + effConstraints <- case t of + NoTail -> return [] + EffectRowTail _ -> return [EffectConstraint appSrcId (sink effsAllowed) effs] + return $ resultTyConstraints ++ effConstraints + +maybeInterpretPunsAsTyCons :: RequiredTy n -> SigmaAtom n -> InfererM i n (SigmaAtom n) maybeInterpretPunsAsTyCons (Check TyKind) (SigmaUVar sn _ (UPunVar v)) = do let v' = UTyConVar v ty <- getUVarType v' @@ -1351,16 +802,22 @@ maybeInterpretPunsAsTyCons _ x = return x type IsDependent = Bool -applySigmaAtom :: EmitsBoth o => SigmaAtom o -> [CAtom o] -> InfererM i o (CAtom o) -applySigmaAtom (SigmaAtom _ f) args = emitExprWithEffects =<< mkApp f args -applySigmaAtom (SigmaUVar _ _ f) args = case f of +inlineTypeAliases :: CAtomName n -> InfererM i n (CAtom n) +inlineTypeAliases v = do + lookupAtomName v >>= \case + LetBound (DeclBinding InlineLet (Atom e)) -> return e + _ -> toAtom <$> toAtomVar v + +applySigmaAtom :: Emits o => SrcId -> SigmaAtom o -> [CAtom o] -> InfererM i o (CAtom o) +applySigmaAtom appSrcId (SigmaAtom _ f) args = emitWithEffects appSrcId =<< mkApp f args +applySigmaAtom appSrcId (SigmaUVar _ _ f) args = case f of UAtomVar f' -> do - f'' <- toAtomVar f' - emitExprWithEffects =<< mkApp (Var f'') args + f'' <- inlineTypeAliases f' + emitWithEffects appSrcId =<< mkApp f'' args UTyConVar f' -> do TyConDef sn roleExpls _ _ <- lookupTyCon f' let expls = snd <$> roleExpls - return $ Type $ NewtypeTyCon $ UserADTType sn f' (TyConParams expls args) + return $ toAtom $ UserADTType sn f' (TyConParams expls args) UDataConVar v -> do (tyCon, i) <- lookupDataCon v applyDataCon tyCon i args @@ -1369,21 +826,25 @@ applySigmaAtom (SigmaUVar _ _ f) args = case f of -- interpret as a data constructor by default (params, dataArgs) <- splitParamPrefix tc args repVal <- makeStructRepVal tc dataArgs - return $ NewtypeCon (UserADTData sn tc params) repVal - UClassVar f' -> do - ClassDef sourceName _ _ _ _ _ _ <- lookupClassDef f' - return $ Type $ DictTy $ DictType sourceName f' args + return $ toAtom $ NewtypeCon (UserADTData sn tc params) repVal + UClassVar f' -> do + ClassDef sourceName builtinName _ _ _ _ _ _ <- lookupClassDef f' + return $ toAtom case builtinName of + Just Ix -> IxDictType singleTyParam + Just Data -> DataDictType singleTyParam + Nothing -> DictType sourceName f' args + where singleTyParam = case args of + [p] -> fromJust $ toMaybeType p + _ -> error "not a single type param" UMethodVar f' -> do MethodBinding className methodIdx <- lookupEnv f' - ClassDef _ _ _ _ paramBs _ _ <- lookupClassDef className + ClassDef _ _ _ _ _ paramBs _ _ <- lookupClassDef className let numParams = nestLength paramBs -- params aren't needed because they're already implied by the dict argument let (dictArg:args') = drop numParams args - emitExprWithEffects =<< mkApplyMethod dictArg methodIdx args' - UEffectVar _ -> error "not implemented" - UEffectOpVar _ -> error "not implemented" -applySigmaAtom (SigmaPartialApp _ f prevArgs) args = - emitExprWithEffects =<< mkApp f (prevArgs ++ args) + emitWithEffects appSrcId =<< mkApplyMethod (fromJust $ toMaybeDict dictArg) methodIdx args' +applySigmaAtom appSrcId (SigmaPartialApp _ f prevArgs) args = + emitWithEffects appSrcId =<< mkApp f (prevArgs ++ args) splitParamPrefix :: EnvReader m => TyConName n -> [CAtom n] -> m n (TyConParams n, [CAtom n]) splitParamPrefix tc args = do @@ -1402,122 +863,205 @@ applyDataCon tc conIx topArgs = do repVal <- return case conDefs of [] -> error "unreachable" [_] -> conProd - _ -> SumVal conTys conIx conProd + _ -> Con $ SumCon conTys conIx conProd where conTys = conDefs <&> \(DataConDef _ _ rty _) -> rty - return $ NewtypeCon (UserADTData sn tc params) repVal + return $ toAtom $ NewtypeCon (UserADTData sn tc params) repVal where wrap :: EnvReader m => CType n -> [CAtom n] -> m n (CAtom n) wrap _ [arg] = return $ arg wrap rty args = case rty of - ProdTy tys -> + TyCon (ProdType tys) -> if nargs == ntys - then return $ ProdVal args - else ProdVal . (curArgs ++) . (:[]) <$> wrap (last tys) remArgs + then return $ Con $ ProdCon args + else Con . ProdCon . (curArgs ++) . (:[]) <$> wrap (last tys) remArgs where nargs = length args; ntys = length tys (curArgs, remArgs) = splitAt (ntys - 1) args - DepPairTy dpt@(DepPairType _ b rty') -> do + TyCon (DepPairTy dpt@(DepPairType _ b rty')) -> do rty'' <- applySubst (b@>SubstVal h) rty' ans <- wrap rty'' t - return $ DepPair h ans dpt + return $ toAtom $ DepPair h ans dpt where h:t = args _ -> error $ "Unexpected data con representation type: " ++ pprint rty -emitExprWithEffects :: EmitsBoth o => CExpr o -> InfererM i o (CAtom o) -emitExprWithEffects expr = do - addEffects $ getEffects expr - emitExpr expr +emitWithEffects :: Emits o => SrcId -> CExpr o -> InfererM i o (CAtom o) +emitWithEffects sid expr = do + addEffects sid $ getEffects expr + emit expr -checkArity :: [Explicitness] -> [a] -> InfererM i o () -checkArity expls args = do +checkExplicitArity :: SrcId -> [Explicitness] -> [a] -> InfererM i o () +checkExplicitArity sid expls args = do let arity = length [() | Explicit <- expls] let numArgs = length args - when (numArgs /= arity) do - throw TypeErr $ "Wrong number of positional arguments provided. Expected " ++ - pprint arity ++ " but got " ++ pprint numArgs + when (numArgs /= arity) $ throw sid $ ArityErr arity numArgs + +type MixedArgs arg = ([arg], [(SourceName, arg)]) -- positional args, named args +data Constraint (n::S) = + TypeConstraint SrcId (CType n) (CType n) + -- permitted effects (no inference vars), proposed effects + | EffectConstraint SrcId (EffectRow CoreIR n) (EffectRow CoreIR n) + +type Constraints = ListE Constraint +type ConstrainedBinders n = ([IsDependent], Abs (Nest CBinder) Constraints n) + +buildConstraints + :: HasNamesE e + => Abs (Nest CBinder) e o + -> (forall o'. DExt o o' => [CAtom o'] -> e o' -> EnvReaderM o' [Constraint o']) + -> InfererM i o (ConstrainedBinders o) +buildConstraints ab cont = liftEnvReaderM do + refreshAbs ab \bs e -> do + cs <- cont (toAtom <$> bindersVars bs) e + return (getDependence (Abs bs e), Abs bs $ ListE cs) + where + getDependence :: HasNamesE e => Abs (Nest CBinder) e n -> [IsDependent] + getDependence (Abs Empty _) = [] + getDependence (Abs (Nest b bs) e) = + (binderName b `isFreeIn` Abs bs e) : getDependence (Abs bs e) -- TODO: check that there are no extra named args provided inferMixedArgs - :: forall arg i o e - . (ExplicitArg arg, EmitsBoth o, SubstE (SubstVal Atom) e, SinkableE e, HoistableE e) - => SourceName -> [Explicitness] - -> Abs (Nest CBinder) e o -> [arg i] -> [(SourceName, arg i)] + :: forall arg i o . (Emits o, ExplicitArg arg) + => SrcId -> SourceName + -> [Explicitness] -> ConstrainedBinders o + -> MixedArgs (arg i) -> InfererM i o [CAtom o] -inferMixedArgs fSourceName explsTop bsAbs posArgs namedArgs = do - checkNamedArgValidity explsTop (map fst namedArgs) - liftM fst $ runStreamReaderT1 posArgs $ go explsTop bsAbs +inferMixedArgs appSrcId fSourceName explsTop (dependenceTop, bsAbs) argsTop@(_, namedArgsTop) = do + checkNamedArgValidity appSrcId explsTop (map fst namedArgsTop) + liftSolverM $ fromListE <$> go explsTop dependenceTop bsAbs argsTop where - go :: (EmitsBoth o, SubstE (SubstVal Atom) e, SinkableE e, HoistableE e) - => [Explicitness] -> Abs (Nest CBinder) e o - -> StreamReaderT1 (arg i) (InfererM i) o [CAtom o] - go [] (Abs Empty _) = return [] - go (expl:expls) (Abs (Nest b bs) result) = do - let rest = Abs bs result - let isDependent = binderName b `isFreeIn` rest - arg <- inferMixedArg isDependent (binderType b) expl - arg' <- lift11 $ zonk arg - rest' <- applySubst (b @> SubstVal arg') rest - (arg:) <$> go expls rest' - go _ _ = error "zip error" + go :: Emits oo + => [Explicitness] -> [IsDependent] -> Abs (Nest CBinder) Constraints oo -> MixedArgs (arg i) + -> SolverM i oo (ListE CAtom oo) + go expls dependence (Abs bs cs) args = do + cs' <- eagerlyApplyConstraints bs cs + case (expls, dependence, bs) of + ([], [], Empty) -> return mempty + (expl:explsRest, isDependent:dependenceRest, Nest b bsRest) -> do + inferMixedArg isDependent (binderType b) expl args \arg restArgs -> do + bs' <- applySubst (b @> SubstVal arg) (Abs bsRest cs') + (ListE [arg] <>) <$> go explsRest dependenceRest bs' restArgs + (_, _, _) -> error "zip error" + + eagerlyApplyConstraints + :: Nest CBinder oo oo' -> Constraints oo' + -> SolverM i oo (Constraints oo') + eagerlyApplyConstraints Empty (ListE cs) = mapM_ applyConstraint cs >> return (ListE []) + eagerlyApplyConstraints bs (ListE cs) = ListE <$> forMFilter cs \c -> do + case hoist bs c of + HoistSuccess c' -> case c' of + TypeConstraint _ _ _ -> applyConstraint c' >> return Nothing + EffectConstraint _ _ (EffectRow specificEffs _) -> + hasInferenceVars specificEffs >>= \case + False -> applyConstraint c' >> return Nothing + -- we delay applying the constraint in this case because we might + -- learn more about the specific effects after we've seen more + -- arguments (like a `Ref h a` that tells us about the `h`) + True -> return $ Just c + HoistFailure _ -> return $ Just c + + inferMixedArg + :: (Emits oo, Zonkable e) => IsDependent -> CType oo -> Explicitness -> MixedArgs (arg i) + -> (forall o'. (Emits o', DExt oo o') => CAtom o' -> MixedArgs (arg i) -> SolverM i o' (e o')) + -> SolverM i oo (e oo) + inferMixedArg isDependent argTy' expl args cont = do + argTy <- zonk argTy' + case expl of + Explicit -> do + -- this should succeed because we've already done the arity check + (arg:argsRest, namedArgs) <- return args + if isHole arg + then do + let desc = (pprint fSourceName, "_") + withFreshUnificationVar appSrcId (ImplicitArgInfVar desc) argTy \v -> + cont (toAtom v) (argsRest, namedArgs) + else do + arg' <- checkOrInferExplicitArg isDependent arg argTy + withDistinct $ cont arg' (argsRest, namedArgs) + Inferred argName infMech -> do + let desc = (pprint $ fSourceName, fromMaybe "_" (fmap pprint argName)) + case lookupNamedArg args argName of + Just arg -> do + arg' <- checkOrInferExplicitArg isDependent arg argTy + withDistinct $ cont arg' args + Nothing -> case infMech of + Unify -> withFreshUnificationVar appSrcId (ImplicitArgInfVar desc) argTy \v -> cont (toAtom v) args + Synth _ -> withDict appSrcId argTy \d -> cont d args + + checkOrInferExplicitArg :: Emits oo => Bool -> arg i -> CType oo -> SolverM i oo (CAtom oo) + checkOrInferExplicitArg isDependent arg argTy = do + arg' <- lift11 $ withoutInfVarsPartial argTy >>= \case + Just partialTy -> case isDependent of + True -> checkExplicitDependentArg arg partialTy + False -> checkExplicitNonDependentArg arg partialTy + Nothing -> inferExplicitArg arg + constrainEq (explicitArgSrcId arg) argTy (getType arg') + return arg' - inferMixedArg :: EmitsBoth o => IsDependent -> CType o -> Explicitness - -> StreamReaderT1 (arg i) (InfererM i) o (CAtom o) - inferMixedArg isDependent argTy = \case - Explicit -> do - -- this should succeed because we've already done the arity check - Just arg <- readStream - lift11 $ checkExplicitArg isDependent arg argTy - Inferred argName infMech -> lift11 do - case lookupNamedArg argName of - Nothing -> getImplicitArg (fSourceName, fromMaybe "_" argName) infMech argTy - Just arg -> checkExplicitArg isDependent arg argTy - - lookupNamedArg :: Maybe SourceName -> Maybe (arg i) - lookupNamedArg Nothing = Nothing - lookupNamedArg (Just v) = lookup v namedArgs - -checkNamedArgValidity :: Fallible m => [Explicitness] -> [SourceName] -> m () -checkNamedArgValidity expls offeredNames = do + lookupNamedArg :: MixedArgs x -> Maybe SourceName -> Maybe x + lookupNamedArg _ Nothing = Nothing + lookupNamedArg (_, namedArgs) (Just v) = lookup v namedArgs + + withoutInfVarsPartial :: CType n -> InfererM i n (Maybe (PartialType n)) + withoutInfVarsPartial = \case + TyCon (Pi piTy) -> + withoutInfVars piTy >>= \case + Just piTy' -> return $ Just $ PartialType $ piAsPartialPi piTy' + Nothing -> withoutInfVars $ PartialType $ piAsPartialPiDropResultTy piTy + ty -> liftM (FullType <$>) $ withoutInfVars ty + + withoutInfVars :: HoistableE e => e n -> InfererM i n (Maybe (e n)) + withoutInfVars x = hasInferenceVars x >>= \case + True -> return Nothing + False -> return $ Just x + +checkNamedArgValidity :: Fallible m => SrcId -> [Explicitness] -> [SourceName] -> m () +checkNamedArgValidity sid expls offeredNames = do let explToMaybeName = \case Explicit -> Nothing Inferred v _ -> v let acceptedNames = catMaybes $ map explToMaybeName expls let duplicates = repeated offeredNames - when (not $ null duplicates) do - throw TypeErr $ "Repeated names offered" ++ pprint duplicates + when (not $ null duplicates) $ throw sid $ RepeatedOptionalArgs $ map pprint duplicates let unrecognizedNames = filter (not . (`elem` acceptedNames)) offeredNames when (not $ null unrecognizedNames) do - throw TypeErr $ "Unrecognized named arguments: " ++ pprint unrecognizedNames - ++ "\nShould be one of: " ++ pprint acceptedNames + throw sid $ UnrecognizedOptionalArgs (map pprint unrecognizedNames) (map pprint acceptedNames) -inferPrimArg :: EmitsBoth o => UExpr i -> InfererM i o (CAtom o) +inferPrimArg :: Emits o => UExpr i -> InfererM i o (CAtom o) inferPrimArg x = do - xBlock <- buildBlockInf $ inferRho noHint x - EffTy _ ty <- blockEffTy xBlock - case ty of - TyKind -> cheapReduce xBlock >>= \case + xBlock <- buildBlock $ bottomUp x + case getType xBlock of + TyKind -> reduceExpr xBlock >>= \case Just reduced -> return reduced - _ -> throw CompilerErr "Type args to primops must be reducible" - _ -> emitBlock xBlock + _ -> throwInternal "Type args to primops must be reducible" + _ -> emit xBlock matchPrimApp :: Emits o => PrimName -> [CAtom o] -> InfererM i o (CAtom o) matchPrimApp = \case - UNat -> \case ~[] -> return $ Type $ NewtypeTyCon Nat - UFin -> \case ~[n] -> return $ Type $ NewtypeTyCon (Fin n) - UEffectRowKind -> \case ~[] -> return $ Type $ NewtypeTyCon EffectRowKind - UBaseType b -> \case ~[] -> return $ Type $ TC $ BaseType b - UNatCon -> \case ~[x] -> return $ NewtypeCon NatCon x - UPrimTC op -> \x -> Type . TC <$> matchGenericOp (Right op) x - UCon op -> \x -> Con <$> matchGenericOp (Right op) x - UMiscOp op -> \x -> emitOp =<< MiscOp <$> matchGenericOp op x - UMemOp op -> \x -> emitOp =<< MemOp <$> matchGenericOp op x - UBinOp op -> \case ~[x, y] -> emitOp $ BinOp op x y - UUnOp op -> \case ~[x] -> emitOp $ UnOp op x - UMAsk -> \case ~[r] -> emitOp $ RefOp r MAsk - UMGet -> \case ~[r] -> emitOp $ RefOp r MGet - UMPut -> \case ~[r, x] -> emitOp $ RefOp r $ MPut x + UNat -> \case ~[] -> return $ toAtom $ NewtypeTyCon Nat + UFin -> \case ~[n] -> return $ toAtom $ NewtypeTyCon (Fin n) + UEffectRowKind -> \case ~[] -> return $ toAtom $ NewtypeTyCon EffectRowKind + UBaseType b -> \case ~[] -> return $ toAtomR $ BaseType b + UNatCon -> \case ~[x] -> return $ toAtom $ NewtypeCon NatCon x + UPrimTC tc -> case tc of + P.ProdType -> \ts -> return $ toAtom $ ProdType $ map (fromJust . toMaybeType) ts + P.SumType -> \ts -> return $ toAtom $ SumType $ map (fromJust . toMaybeType) ts + P.RefType -> \case ~[h, a] -> return $ toAtom $ RefType h (fromJust $ toMaybeType a) + P.TypeKind -> \case ~[] -> return $ Con $ TyConAtom $ TypeKind + P.HeapType -> \case ~[] -> return $ Con $ TyConAtom $ HeapType + UCon con -> case con of + P.ProdCon -> \xs -> return $ toAtom $ ProdCon xs + P.HeapVal -> \case ~[] -> return $ toAtom HeapVal + P.SumCon _ -> error "not supported" + UMiscOp op -> \x -> emit =<< MiscOp <$> matchGenericOp op x + UMemOp op -> \x -> emit =<< MemOp <$> matchGenericOp op x + UBinOp op -> \case ~[x, y] -> emit $ BinOp op x y + UUnOp op -> \case ~[x] -> emit $ UnOp op x + UMAsk -> \case ~[r] -> emit $ RefOp r MAsk + UMGet -> \case ~[r] -> emit $ RefOp r MGet + UMPut -> \case ~[r, x] -> emit $ RefOp r $ MPut x UIndexRef -> \case ~[r, i] -> indexRef r i - UApplyMethod i -> \case ~(d:args) -> emitExpr =<< mkApplyMethod d i args + UApplyMethod i -> \case ~(d:args) -> emit =<< mkApplyMethod (fromJust $ toMaybeDict d) i args ULinearize -> \case ~[f, x] -> do f' <- lam1 f; emitHof $ Linearize f' x UTranspose -> \case ~[f, x] -> do f' <- lam1 f; emitHof $ Transpose f' x URunReader -> \case ~[x, f] -> do f' <- lam2 f; emitHof $ RunReader x f' @@ -1525,13 +1069,13 @@ matchPrimApp = \case UWhile -> \case ~[f] -> do f' <- lam0 f; emitHof $ While f' URunIO -> \case ~[f] -> do f' <- lam0 f; emitHof $ RunIO f' UCatchException-> \case ~[f] -> do f' <- lam0 f; emitHof =<< mkCatchException f' - UMExtend -> \case ~[r, z, f, x] -> do f' <- lam2 f; emitOp $ RefOp r $ MExtend (BaseMonoid z f') x + UMExtend -> \case ~[r, z, f, x] -> do f' <- lam2 f; emit $ RefOp r $ MExtend (BaseMonoid z f') x URunWriter -> \args -> do [idVal, combiner, f] <- return args combiner' <- lam2 combiner f' <- lam2 f emitHof $ RunWriter Nothing (BaseMonoid idVal combiner') f' - p -> \case xs -> throw TypeErr $ "Bad primitive application: " ++ show (p, xs) + p -> \case xs -> throwInternal $ "Bad primitive application: " ++ show (p, xs) where lam2 :: Fallible m => CAtom n -> m (LamExpr CoreIR n) lam2 x = do @@ -1543,7 +1087,7 @@ matchPrimApp = \case ExplicitCoreLam (UnaryNest b) body <- return x return $ UnaryLamExpr b body - lam0 :: Fallible m => CAtom n -> m (CBlock n) + lam0 :: Fallible m => CAtom n -> m (CExpr n) lam0 x = do ExplicitCoreLam Empty body <- return x return body @@ -1553,59 +1097,39 @@ matchPrimApp = \case (tyArgs, dataArgs) <- partitionEithers <$> forM xs \x -> do case getType x of TyKind -> do - Type x' <- return x + Just x' <- return $ toMaybeType x return $ Left x' _ -> return $ Right x return $ fromJust $ toOp $ GenericOpRep op tyArgs dataArgs [] -pattern ExplicitCoreLam :: Nest CBinder n l -> CBlock l -> CAtom n -pattern ExplicitCoreLam bs body <- Lam (CoreLamExpr _ (LamExpr bs body)) +pattern ExplicitCoreLam :: Nest CBinder n l -> CExpr l -> CAtom n +pattern ExplicitCoreLam bs body <- Con (Lam (CoreLamExpr _ (LamExpr bs body))) -- === n-ary applications === -inferTabApp :: EmitsBoth o => SrcPosCtx -> CAtom o -> [UExpr i] -> InfererM i o (CAtom o) -inferTabApp tabCtx tab args = addSrcContext tabCtx do +inferTabApp :: Emits o => SrcId -> CAtom o -> [UExpr i] -> InfererM i o (CAtom o) +inferTabApp tabSrcId tab args = do tabTy <- return $ getType tab - args' <- inferNaryTabAppArgs tabTy args - tab' <- zonk tab - emitExpr =<< mkTabApp tab' args' - -inferNaryTabAppArgs - :: EmitsBoth o - => CType o -> [UExpr i] -> InfererM i o [CAtom o] -inferNaryTabAppArgs _ [] = return [] -inferNaryTabAppArgs tabTy (arg:rest) = do - TabPiType _ b resultTy <- fromTabPiType True tabTy - let ixTy = binderType b - let isDependent = binderName b `isFreeIn` resultTy - arg' <- if isDependent - then checkSigmaDependent (getNameHint b) arg ixTy - else checkSigma (getNameHint b) arg ixTy - arg'' <- zonk arg' - resultTy' <- applySubst (b @> SubstVal arg'') resultTy - rest' <- inferNaryTabAppArgs resultTy' rest - return $ arg'':rest' - -checkSigmaDependent :: EmitsBoth o - => NameHint -> UExpr i -> CType o -> InfererM i o (CAtom o) -checkSigmaDependent hint e@(WithSrcE ctx _) ty = addSrcContext ctx $ - withReducibleEmissions depFunErrMsg $ checkSigma hint e (sink ty) - where - depFunErrMsg = - "Dependent functions can only be applied to fully evaluated expressions. " ++ - "Bind the argument to a name before you apply the function." - -withReducibleEmissions - :: ( EmitsInf o, SinkableE e, RenameE e, SubstE AtomSubstVal e - , HoistableE e, CheaplyReducibleE CoreIR e e) - => String - -> (forall o' . (EmitsBoth o', DExt o o') => InfererM i o' (e o')) - -> InfererM i o (e o) -withReducibleEmissions msg cont = do - Abs decls result <- buildDeclsInf cont - cheapReduceWithDecls decls result >>= \case - Just t -> return t - _ -> throw TypeErr msg + args' <- inferNaryTabAppArgs tabSrcId tabTy args + naryTabApp tab args' + +inferNaryTabAppArgs :: Emits o => SrcId -> CType o -> [UExpr i] -> InfererM i o [CAtom o] +inferNaryTabAppArgs _ _ [] = return [] +inferNaryTabAppArgs tabSrcId tabTy (arg:rest) = case tabTy of + TyCon (TabPi (TabPiType _ b resultTy)) -> do + let ixTy = binderType b + let isDependent = binderName b `isFreeIn` resultTy + arg' <- if isDependent + then checkSigmaDependent arg (FullType ixTy) + else topDown ixTy arg + resultTy' <- applySubst (b @> SubstVal arg') resultTy + rest' <- inferNaryTabAppArgs tabSrcId resultTy' rest + return $ arg':rest' + _ -> throw tabSrcId $ EliminationErr "table type" (pprint tabTy) + +checkSigmaDependent :: UExpr i -> PartialType o -> InfererM i o (CAtom o) +checkSigmaDependent e ty = withReducibleEmissions (getSrcId e) CantReduceDependentArg $ + topDownPartial (sink ty) e -- === sorting case alternatives === @@ -1617,19 +1141,12 @@ instance SinkableE IndexedAlt where buildNthOrderedAlt :: (Emits n, Builder CoreIR m) => [IndexedAlt n] -> CType n -> CType n -> Int -> CAtom n -> m n (CAtom n) -buildNthOrderedAlt alts scrutTy resultTy i v = do - case lookup (nthCaseAltIdx scrutTy i) [(idx, alt) | IndexedAlt idx alt <- alts] of +buildNthOrderedAlt alts _ resultTy i v = do + case lookup i [(idx, alt) | IndexedAlt idx alt <- alts] of Nothing -> do resultTy' <- sinkM resultTy - emitOp $ MiscOp $ ThrowError resultTy' - Just alt -> applyAbs alt (SubstVal v) >>= emitBlock - --- converts from the ordinal index used in the core IR to the more complicated --- `CaseAltIndex` used in the surface IR. -nthCaseAltIdx :: CType n -> Int -> CaseAltIndex -nthCaseAltIdx ty i = case ty of - TypeCon _ _ _ -> i - _ -> error $ "can't pattern-match on: " <> pprint ty + emit $ ThrowError resultTy' + Just alt -> applyAbs alt (SubstVal v) >>= emit buildMonomorphicCase :: (Emits n, ScopableBuilder CoreIR m) @@ -1648,7 +1165,7 @@ buildSortedCase :: (Fallible1 m, Builder CoreIR m, Emits n) buildSortedCase scrut alts resultTy = do scrutTy <- return $ getType scrut case scrutTy of - TypeCon _ defName _ -> do + TyCon (NewtypeTyCon (UserADTType _ defName _)) -> do TyConDef _ _ _ (ADTCons cons) <- lookupTyCon defName case cons of [] -> error "case of void?" @@ -1656,91 +1173,72 @@ buildSortedCase scrut alts resultTy = do [_] -> do let [IndexedAlt _ alt] = alts scrut' <- unwrapNewtype scrut - emitBlock =<< applyAbs alt (SubstVal scrut') - _ -> liftEmitBuilder $ buildMonomorphicCase alts scrut resultTy + emit =<< applyAbs alt (SubstVal scrut') + _ -> do + scrut' <- unwrapNewtype scrut + liftEmitBuilder $ buildMonomorphicCase alts scrut' resultTy _ -> fail $ "Unexpected case expression type: " <> pprint scrutTy -- TODO: cache this with the instance def (requires a recursive binding) instanceFun :: EnvReader m => InstanceName n -> AppExplicitness -> m n (CAtom n) instanceFun instanceName appExpl = do InstanceDef _ expls bs _ _ <- lookupInstanceDef instanceName - ab <- liftEnvReaderM $ refreshAbs (Abs bs UnitE) \bs' UnitE -> do + liftEnvReaderM $ refreshAbs (Abs bs UnitE) \bs' UnitE -> do args <- mapM toAtomVar $ nestToNames bs' - result <- mkDictAtom $ InstanceDict (sink instanceName) (Var <$> args) - return $ Abs bs' (PairE Pure (WithoutDecls result)) - Lam <$> coreLamExpr appExpl (snd<$>expls) ab - -checkMaybeAnnExpr :: EmitsBoth o - => NameHint -> Maybe (UType i) -> UExpr i -> InfererM i o (CAtom o) -checkMaybeAnnExpr hint ty expr = confuseGHC >>= \_ -> case ty of - Nothing -> inferSigma hint expr - Just ty' -> checkSigma hint expr =<< zonk =<< checkUType ty' - -inferRole :: CType o -> Explicitness -> InfererM i o ParamRole -inferRole ty = \case - Inferred _ (Synth _) -> return DictParam - _ -> do - zonk ty >>= \case - TyKind -> return TypeParam - ty' -> isData ty' >>= \case - True -> return DataParam - -- TODO(dougalm): the `False` branch should throw an error but that's - -- currently too conservative. e.g. `data RangeFrom q:Type i:q = ...` - -- fails because `q` isn't data. We should be able to fix it once we - -- have a `Data a` class (see issue #680). - False -> return DataParam -{-# INLINE inferRole #-} - -inferTyConDef :: EmitsInf o => UDataDef i -> InfererM i o (TyConDef o) + result <- toAtom <$> mkInstanceDict (sink instanceName) (toAtom <$> args) + let effTy = EffTy Pure (getType result) + let piTy = CorePiType appExpl (snd<$>expls) bs' effTy + return $ toAtom $ CoreLamExpr piTy (LamExpr bs' $ Atom result) + +checkMaybeAnnExpr :: Emits o => Maybe (UType i) -> UExpr i -> InfererM i o (CAtom o) +checkMaybeAnnExpr ty expr = confuseGHC >>= \_ -> case ty of + Nothing -> bottomUp expr + Just ty' -> do + ty'' <- checkUType ty' + topDown ty'' expr + +inferTyConDef :: UDataDef i -> InfererM i o (TyConDef o) inferTyConDef (UDataDef tyConName paramBs dataCons) = do - Abs paramBs' dataCons' <- - withRoleUBinders paramBs do - ADTCons <$> mapM inferDataCon dataCons - let (roleExpls, paramBs'') = unzipAttrs paramBs' - return (TyConDef tyConName roleExpls paramBs'' dataCons') + withRoleUBinders paramBs \(ZipB roleExpls paramBs') -> do + dataCons' <- ADTCons <$> mapM inferDataCon dataCons + return (TyConDef tyConName roleExpls paramBs' dataCons') -inferStructDef :: EmitsInf o => UStructDef i -> InfererM i o (TyConDef o) +inferStructDef :: UStructDef i -> InfererM i o (TyConDef o) inferStructDef (UStructDef tyConName paramBs fields _) = do - let (fieldNames, fieldTys) = unzip fields - Abs paramBs' dataConDefs <- withRoleUBinders paramBs do + withRoleUBinders paramBs \(ZipB roleExpls paramBs') -> do + let (fieldNames, fieldTys) = unzip fields tys <- mapM checkUType fieldTys - return $ StructFields $ zip fieldNames tys - let (roleExpls, paramBs'') = unzipAttrs paramBs' - return $ TyConDef tyConName roleExpls paramBs'' dataConDefs + let dataConDefs = StructFields $ zip (withoutSrc <$> fieldNames) tys + return $ TyConDef tyConName roleExpls paramBs' dataConDefs inferDotMethod - :: EmitsInf o => TyConName o - -> Abs (Nest UOptAnnBinder) (Abs UAtomBinder ULamExpr) i + :: TyConName o + -> Abs (Nest UAnnBinder) (Abs UAtomBinder ULamExpr) i -> InfererM i o (CoreLamExpr o) inferDotMethod tc (Abs uparamBs (Abs selfB lam)) = do TyConDef sn roleExpls paramBs _ <- lookupTyCon tc let expls = snd <$> roleExpls - ab <- buildNaryAbsInfWithExpl expls (Abs paramBs UnitE) \paramVs -> do - let paramVs' = catMaybes $ zip expls paramVs <&> \(expl, v) -> case expl of - Inferred _ (Synth _) -> Nothing - _ -> Just v - extendRenamer (uparamBs @@> (atomVarName <$> paramVs')) do - let selfTy = NewtypeTyCon $ UserADTType sn (sink tc) (TyConParams expls (Var <$> paramVs)) - buildAbsInfWithExpl "self" Explicit selfTy \vSelf -> - extendRenamer (selfB @> atomVarName vSelf) $ inferULam lam - Abs paramBs'' (Abs selfB' lam') <- return ab - return $ prependCoreLamExpr (paramBs'' >>> UnaryNest selfB') lam' - -prependCoreLamExpr :: Nest (WithExpl CBinder) n l -> CoreLamExpr l -> CoreLamExpr n -prependCoreLamExpr bs e = case e of - CoreLamExpr (CorePiType appExpl piExpls piBs effTy) (LamExpr lamBs body) -> do - let (expls, bs') = unzipAttrs bs - let piType = CorePiType appExpl (expls <> piExpls) (bs' >>> piBs) effTy - let lamExpr = LamExpr (fmapNest withoutAttr bs >>> lamBs) body - CoreLamExpr piType lamExpr - -inferDataCon :: EmitsInf o => (SourceName, UDataDefTrail i) -> InfererM i o (DataConDef o) + withFreshBindersInf expls (Abs paramBs UnitE) \paramBs' UnitE -> do + let paramVs = bindersVars paramBs' + extendRenamer (uparamBs @@> (atomVarName <$> paramVs)) do + let selfTy = toType $ UserADTType sn (sink tc) (TyConParams expls (toAtom <$> paramVs)) + withFreshBinderInf "self" Explicit selfTy \selfB' -> do + lam' <- extendRenamer (selfB @> binderName selfB') $ inferULam lam + return $ prependCoreLamExpr (expls ++ [Explicit]) (paramBs' >>> UnaryNest selfB') lam' + + where + prependCoreLamExpr :: [Explicitness] -> Nest CBinder n l -> CoreLamExpr l -> CoreLamExpr n + prependCoreLamExpr expls bs e = case e of + CoreLamExpr (CorePiType appExpl piExpls piBs effTy) (LamExpr lamBs body) -> do + let piType = CorePiType appExpl (expls <> piExpls) (bs >>> piBs) effTy + let lamExpr = LamExpr (bs >>> lamBs) body + CoreLamExpr piType lamExpr + +inferDataCon :: (SourceName, UDataDefTrail i) -> InfererM i o (DataConDef o) inferDataCon (sourceName, UDataDefTrail argBs) = do - let expls = nestToList (const Explicit) argBs - Abs argBs' UnitE <- withUBinders (expls, argBs) \_ -> return UnitE - let argBs'' = Abs (fmapNest withoutAttr argBs') UnitE - let (repTy, projIdxs) = dataConRepTy argBs'' - return $ DataConDef sourceName argBs'' repTy projIdxs + withUBinders argBs \(ZipB _ argBs') -> do + let (repTy, projIdxs) = dataConRepTy $ EmptyAbs argBs' + return $ DataConDef sourceName (EmptyAbs argBs') repTy projIdxs dataConRepTy :: EmptyAbs (Nest CBinder) n -> (CType n, [[Projection]]) dataConRepTy (Abs topBs UnitE) = case topBs of @@ -1752,7 +1250,7 @@ dataConRepTy (Abs topBs UnitE) = case topBs of Empty -> case revAcc of [] -> error "should never happen" [ty] -> (ty, [projIdxs]) - _ -> ( ProdTy $ reverse revAcc + _ -> ( toType $ ProdType $ reverse revAcc , iota (length revAcc) <&> \i -> ProjectProduct i:projIdxs ) Nest b bs -> case hoist b (EmptyAbs bs) of HoistSuccess (Abs bs' UnitE) -> go (binderType b:revAcc) projIdxs bs' @@ -1761,195 +1259,194 @@ dataConRepTy (Abs topBs UnitE) = case topBs of accSize = length revAcc (fullTy, depTyIdxs) = case revAcc of [] -> (depTy, []) - _ -> (ProdTy $ reverse revAcc ++ [depTy], [ProjectProduct accSize]) + _ -> (toType $ ProdType $ reverse revAcc ++ [depTy], [ProjectProduct accSize]) (tailTy, tailIdxs) = go [] (ProjectProduct 1 : (depTyIdxs ++ projIdxs)) bs idxs = (iota accSize <&> \i -> ProjectProduct i : projIdxs) ++ ((ProjectProduct 0 : (depTyIdxs ++ projIdxs)) : tailIdxs) - depTy = DepPairTy $ DepPairType ExplicitDepPair b tailTy + depTy = toType $ DepPairTy $ DepPairType ExplicitDepPair b tailTy inferClassDef - :: EmitsInf o - => SourceName -> [SourceName] - -> UOptAnnExplBinders i i' - -> [UType i'] + :: SourceName -> [SourceName] -> Nest UAnnBinder i i' -> [UType i'] -> InfererM i o (ClassDef o) -inferClassDef className methodNames paramBs@(expls, paramBs') methods = do - let paramBsWithAttrBs = zipWithNest paramBs' expls \b expl -> WithAttrB expl b - let paramNames = catMaybes $ nestToList - (\(WithAttrB expl (UAnnBinder b _ _)) -> case expl of - Inferred _ (Synth _) -> Nothing - _ -> Just $ Just $ getSourceName b) paramBsWithAttrBs - ab <- withRoleUBinders paramBs do - ListE <$> forM methods \m -> do - checkUType m >>= \case - Pi t -> return t - t -> return $ CorePiType ImplicitApp [] Empty (EffTy Pure t) - Abs (PairB bs scs) (ListE mtys) <- identifySuperclasses ab - let (roleExpls, bs') = unzipAttrs bs - return $ ClassDef className methodNames paramNames roleExpls bs' scs mtys - -identifySuperclasses - :: RenameE e => Abs (Nest (WithRoleExpl CBinder)) e n - -> InfererM i n (Abs (PairB (Nest (WithRoleExpl CBinder)) (Nest CBinder)) e n) -identifySuperclasses ab = do - refreshAbs ab \bs e -> do - bs' <- partitionBinders bs \b@(WithAttrB (_, expl) b') -> case expl of - Explicit -> return $ LeftB b - Inferred _ Unify -> throw TypeErr "Interfaces can't have implicit parameters" - Inferred _ (Synth _) -> return $ RightB b' - return $ Abs bs' e - -withUBinders - :: (EmitsInf o, HasNamesE e, SubstE AtomSubstVal e, SinkableE e) - => UAnnExplBinders req i i' - -> (forall o'. (EmitsInf o', DExt o o') => [CAtomVar o'] -> InfererM i' o' (e o')) - -> InfererM i o (Abs (Nest (WithExpl CBinder)) e o) -withUBinders bs cont = case bs of - ([], Empty) -> getDistinct >>= \Distinct -> Abs Empty <$> cont [] - (expl:expls, Nest (UAnnBinder b ann cs) rest) -> do - ann' <- checkAnn (getSourceName b) ann - prependAbs <$> buildAbsInfWithExpl (getNameHint b) expl ann' \v -> - concatAbs <$> withConstraintBinders cs v do - extendSubst (b@>sink (atomVarName v)) $ withUBinders (expls, rest) \vs -> - cont (sink v : vs) - _ -> error "zip error" - -withConstraintBinders - :: (EmitsInf o, HasNamesE e, SubstE AtomSubstVal e, RenameE e, SinkableE e) - => [UConstraint i] - -> CAtomVar o - -> (forall o'. (EmitsInf o', DExt o o') => InfererM i o' (e o')) +inferClassDef className methodNames paramBs methodTys = do + withRoleUBinders paramBs \(ZipB roleExpls paramBs') -> do + let paramNames = catMaybes $ nestToListFlip paramBs \(UAnnBinder expl b _ _) -> + case expl of Inferred _ (Synth _) -> Nothing + _ -> Just $ Just $ getSourceName b + methodTys' <- forM methodTys \m -> do + checkUType m >>= \case + TyCon (Pi t) -> return t + t -> return $ CorePiType ImplicitApp [] Empty (EffTy Pure t) + PairB paramBs'' superclassBs <- partitionBinders rootSrcId (zipAttrs roleExpls paramBs') $ + \b@(WithAttrB (_, expl) b') -> case expl of + Explicit -> return $ LeftB b + -- TODO: Add a proper SrcId here. We'll need to plumb it through from the original UBinders + Inferred _ Unify -> throw rootSrcId InterfacesNoImplicitParams + Inferred _ (Synth _) -> return $ RightB b' + let (roleExpls', paramBs''') = unzipAttrs paramBs'' + builtinName <- case className of + -- TODO: this is hacky. Let's just make the Ix class, including its + -- methods, fully built-in instead of prelude-defined. + "Ix" -> return $ Just Ix + "Data" -> return $ Just Data + _ -> return Nothing + return $ ClassDef className builtinName methodNames paramNames roleExpls' paramBs''' superclassBs methodTys' + +withUBinder :: UAnnBinder i i' -> InfererCPSB2 (WithExpl CBinder) i i' o a +withUBinder (UAnnBinder expl b ann cs) cont = do + ty <- inferAnn (getSrcId b) ann cs + withFreshBinderInf (getNameHint b) expl ty \b' -> + extendSubst (b@>binderName b') $ cont (WithAttrB expl b') + +withUBinders :: Nest UAnnBinder i i' -> InfererCPSB2 (Nest (WithExpl CBinder)) i i' o a +withUBinders bs cont = do + Abs bs' UnitE <- inferUBinders bs \_ -> return UnitE + let (expls, bs'') = unzipAttrs bs' + withFreshBindersInf expls (Abs bs'' UnitE) \bs''' UnitE -> do + extendSubst (bs@@> (atomVarName <$> bindersVars bs''')) $ + cont $ zipAttrs expls bs''' + +inferUBinders + :: Zonkable e => Nest UAnnBinder i i' + -> (forall o'. DExt o o' => [CAtomName o'] -> InfererM i' o' (e o')) -> InfererM i o (Abs (Nest (WithExpl CBinder)) e o) -withConstraintBinders [] _ cont = getDistinct >>= \Distinct -> Abs Empty <$> cont -withConstraintBinders (c:cs) v cont = do - Type dictTy <- withReducibleEmissions "Can't reduce interface constraint" do - c' <- inferWithoutInstantiation c >>= zonk - dropSubst $ checkOrInferApp c' [Var $ sink v] [] (Check TyKind) - prependAbs <$> buildAbsInfWithExpl "d" (Inferred Nothing (Synth Full)) dictTy \_ -> - withConstraintBinders cs (sink v) cont - -withRoleUBinders - :: forall i i' o e req. (EmitsInf o, HasNamesE e, SubstE AtomSubstVal e, SinkableE e) - => UAnnExplBinders req i i' - -> (forall o'. (EmitsInf o', DExt o o') => InfererM i' o' (e o')) - -> InfererM i o (Abs (Nest (WithRoleExpl CBinder)) e o) -withRoleUBinders roleBs cont = case roleBs of - ([], Empty) -> getDistinct >>= \Distinct -> Abs Empty <$> cont - (expl:expls, Nest (UAnnBinder b ann cs) rest) -> do - ann' <- checkAnn (getSourceName b) ann - Abs b' (Abs bs' e) <- buildAbsInf (getNameHint b) expl ann' \v -> do - Abs ds (Abs bs' e) <- withConstraintBinders cs v $ - extendSubst (b@>sink (atomVarName v)) $ withRoleUBinders (expls, rest) cont - let ds' = fmapNest (\(WithAttrB expl' b') -> WithAttrB (DictParam, expl') b') ds - return $ Abs (ds' >>> bs') e - role <- inferRole (binderType b') expl - return $ Abs (Nest (WithAttrB (role,expl) b') bs') e - _ -> error "zip error" - -inferULam :: EmitsInf o => ULamExpr i -> InfererM i o (CoreLamExpr o) +inferUBinders Empty cont = withDistinct $ Abs Empty <$> cont [] +inferUBinders (Nest (UAnnBinder expl b ann cs) bs) cont = do + -- TODO: factor out the common part of each case (requires an annotated + -- `where` clause because of the rank-2 type) + ty <- inferAnn (getSrcId b) ann cs + withFreshBinderInf (getNameHint b) expl ty \b' -> do + extendSubst (b@>binderName b') do + Abs bs' e <- inferUBinders bs \vs -> cont (sink (binderName b') : vs) + return $ Abs (Nest (WithAttrB expl b') bs') e + +withRoleUBinders :: Nest UAnnBinder i i' -> InfererCPSB2 (Nest (WithRoleExpl CBinder)) i i' o a +withRoleUBinders bs cont = do + withUBinders bs \(ZipB expls bs') -> do + let tys = getType <$> bindersVars bs' + roleExpls <- forM (zip tys expls) \(ty, expl) -> do + role <- inferRole ty expl + return (role, expl) + cont (zipAttrs roleExpls bs') + where + inferRole :: CType o -> Explicitness -> InfererM i o ParamRole + inferRole ty = \case + Inferred _ (Synth _) -> return DictParam + _ -> case ty of + TyKind -> return TypeParam + _ -> isData ty >>= \case + True -> return DataParam + -- TODO(dougalm): the `False` branch should throw an error but that's + -- currently too conservative. e.g. `data RangeFrom q:Type i:q = ...` + -- fails because `q` isn't data. We should be able to fix it once we + -- have a `Data a` class (see issue #680). + False -> return DataParam + {-# INLINE inferRole #-} + +inferAnn :: SrcId -> UAnn i -> [UConstraint i] -> InfererM i o (CType o) +inferAnn binderSrcId ann cs = case ann of + UAnn ty -> checkUType ty + UNoAnn -> case cs of + WithSrcE sid (UVar ~(InternalName _ _ v)):_ -> do + renameM v >>= getUVarType >>= \case + TyCon (Pi (CorePiType ExplicitApp [Explicit] (UnaryNest (_:>ty)) _)) -> return ty + ty -> throw sid $ NotAUnaryConstraint $ pprint ty + _ -> throw binderSrcId AnnotationRequired + +checkULamPartial :: PartialPiType o -> SrcId -> ULamExpr i -> InfererM i o (CoreLamExpr o) +checkULamPartial partialPiTy sid lamExpr = do + PartialPiType piAppExpl expls piBs piEffs piReqTy <- return partialPiTy + ULamExpr lamBs lamAppExpl lamEffs lamResultTy body <- return lamExpr + checkExplicitArity sid expls (nestToList (const ()) lamBs) + when (piAppExpl /= lamAppExpl) $ throw sid $ WrongArrowErr (pprint piAppExpl) (pprint lamAppExpl) + checkLamBinders expls piBs lamBs \lamBs' -> do + PairE piEffs' piReqTy' <- applyRename (piBs @@> (atomVarName <$> bindersVars lamBs')) (PairE piEffs piReqTy) + resultTy <- case (lamResultTy, piReqTy') of + (Nothing, Infer ) -> return Infer + (Just t , Infer ) -> Check <$> checkUType t + (Nothing, Check t) -> Check <$> return t + (Just t , Check t') -> checkUType t >>= expectEq (getSrcId t) t' >> return (Check t') + forM_ lamEffs \lamEffs' -> do + lamEffs'' <- checkUEffRow lamEffs' + expectEq sid (Eff piEffs') (Eff lamEffs'') -- TODO: add source annotations to lambda effects too + body' <- withAllowedEffects piEffs' do + buildBlock $ withBlockDecls body \result -> checkOrInfer (sink resultTy) result + resultTy' <- case resultTy of + Infer -> return $ getType body' + Check t -> return t + let piTy = CorePiType piAppExpl expls lamBs' (EffTy piEffs' resultTy') + return $ CoreLamExpr piTy (LamExpr lamBs' body') + where + checkLamBinders + :: [Explicitness] -> Nest CBinder o any -> Nest UAnnBinder i i' + -> InfererCPSB2 (Nest CBinder) i i' o a + checkLamBinders [] Empty Empty cont = withDistinct $ cont Empty + checkLamBinders (piExpl:piExpls) (Nest (piB:>piAnn) piBs) lamBs cont = do + case piExpl of + Inferred _ _ -> do + withFreshBinderInf (getNameHint piB) piExpl piAnn \b -> do + Abs piBs' UnitE <- applyRename (piB@>binderName b) (EmptyAbs piBs) + checkLamBinders piExpls piBs' lamBs \bs -> cont (Nest b bs) + Explicit -> case lamBs of + Nest (UAnnBinder _ lamB lamAnn _) lamBsRest -> do + case lamAnn of + UAnn lamAnn' -> checkUType lamAnn' >>= expectEq (getSrcId lamAnn') piAnn + UNoAnn -> return () + withFreshBinderInf (getNameHint lamB) Explicit piAnn \b -> do + Abs piBs' UnitE <- applyRename (piB@>binderName b) (EmptyAbs piBs) + extendRenamer (lamB@>sink (binderName b)) $ + checkLamBinders piExpls piBs' lamBsRest \bs -> cont (Nest b bs) + Empty -> error "zip error" + checkLamBinders _ _ _ _ = error "zip error" + +inferUForExpr :: Emits o => UForExpr i -> InfererM i o (LamExpr CoreIR o) +inferUForExpr (UForExpr b body) = do + withUBinder b \(WithAttrB _ b') -> do + body' <- buildBlock $ withBlockDecls body \result -> bottomUp result + return $ LamExpr (UnaryNest b') body' + +checkUForExpr :: Emits o => SrcId -> UForExpr i -> TabPiType CoreIR o -> InfererM i o (LamExpr CoreIR o) +checkUForExpr sid (UForExpr bFor body) (TabPiType _ bPi resultTy) = do + let uLamExpr = ULamExpr (UnaryNest bFor) ExplicitApp Nothing Nothing body + effsAllowed <- infEffects <$> getInfState + partialPi <- liftEnvReaderM $ refreshAbs (Abs bPi resultTy) \bPi' resultTy' -> do + return $ PartialPiType ExplicitApp [Explicit] (UnaryNest bPi') (sink effsAllowed) (Check resultTy') + CoreLamExpr _ lamExpr <- checkULamPartial partialPi sid uLamExpr + return lamExpr + +inferULam :: ULamExpr i -> InfererM i o (CoreLamExpr o) inferULam (ULamExpr bs appExpl effs resultTy body) = do - ab <- withUBinders bs \_ -> do + Abs (ZipB expls bs') (PairE effTy body') <- inferUBinders bs \_ -> do effs' <- fromMaybe Pure <$> mapM checkUEffRow effs resultTy' <- mapM checkUType resultTy - body' <- buildBlockInf $ withAllowedEffects (sink effs') do - case resultTy' of - Nothing -> withBlockDecls body \result -> inferSigma noHint result - Just resultTy'' -> - withBlockDecls body \result -> - checkSigma noHint result (sink resultTy'') - return (PairE effs' body') - Abs bs' (PairE effs' body') <- return ab - let (expls, bs'') = unzipAttrs bs' - case appExpl of - ImplicitApp -> checkImplicitLamRestrictions bs'' effs' - ExplicitApp -> return () - coreLamExpr appExpl expls $ Abs bs'' $ PairE effs' body' - -checkImplicitLamRestrictions :: Nest CBinder o o' -> EffectRow CoreIR o' -> InfererM i o () -checkImplicitLamRestrictions _ _ = return () -- TODO - -checkUForExpr :: EmitsBoth o => UForExpr i -> TabPiType CoreIR o -> InfererM i o (LamExpr CoreIR o) -checkUForExpr (UForExpr (UAnnBinder bFor ann cs) body) tabPi@(TabPiType _ bPi _) = do - unless (null cs) $ throw TypeErr "`for` binders shouldn't have constraints" - let iTy = binderAnn bPi - case ann of - UNoAnn -> return () - UAnn forAnn -> checkUType forAnn >>= constrainTypesEq iTy - Abs b body' <- buildAbsInf (getNameHint bFor) Explicit iTy \i -> do - extendRenamer (bFor@>atomVarName i) do - TabPiType _ bPi' resultTy <- sinkM tabPi - resultTy' <- applyRename (bPi'@>atomVarName i) resultTy - buildBlockInf do - withBlockDecls body \result -> - checkSigma noHint result $ sink resultTy' - return $ LamExpr (UnaryNest b) body' - -inferUForExpr :: EmitsBoth o => UForExpr i -> InfererM i o (LamExpr CoreIR o) -inferUForExpr (UForExpr (UAnnBinder bFor ann cs) body) = do - unless (null cs) $ throw TypeErr "`for` binders shouldn't have constraints" - iTy <- checkAnn (getSourceName bFor) ann - Abs b body' <- buildAbsInf (getNameHint bFor) Explicit iTy \i -> - extendRenamer (bFor@>atomVarName i) $ buildBlockInf $ + body' <- buildBlock $ withAllowedEffects (sink effs') do withBlockDecls body \result -> - checkOrInferRho noHint result Infer - return $ LamExpr (UnaryNest b) body' - -checkULam :: EmitsInf o => ULamExpr i -> CorePiType o -> InfererM i o (CoreLamExpr o) -checkULam (ULamExpr (_, lamBs) lamAppExpl lamEffs lamResultTy body) - (CorePiType piAppExpl expls piBs effTy) = do - checkArity expls (nestToList (const ()) lamBs) - when (piAppExpl /= lamAppExpl) $ throw TypeErr $ "Wrong arrow. Expected " - ++ pprint piAppExpl ++ " got " ++ pprint lamAppExpl - Abs explBs body' <- checkLamBinders expls piBs lamBs \vs -> do - EffTy piEffs' piResultTy' <- applyRename (piBs@@>map atomVarName vs) effTy - case lamResultTy of - Nothing -> return () - Just t -> checkUType t >>= constrainTypesEq piResultTy' - forM_ lamEffs \lamEffs' -> do - lamEffs'' <- checkUEffRow lamEffs' - constrainEq (Eff piEffs') (Eff lamEffs'') - withAllowedEffects piEffs' do - body' <- buildBlockInf do - piResultTy'' <- sinkM piResultTy' - withBlockDecls body \result -> - checkSigma noHint result piResultTy'' - return $ PairE piEffs' body' - let (expls', bs') = unzipAttrs explBs - coreLamExpr piAppExpl expls' $ Abs bs' body' - -checkLamBinders - :: (EmitsInf o, SinkableE e, HoistableE e, SubstE AtomSubstVal e, RenameE e) - => [Explicitness] -> Nest CBinder o any - -> Nest UOptAnnBinder i i' - -> (forall o'. (EmitsInf o', DExt o o') => [CAtomVar o'] -> InfererM i' o' (e o')) - -> InfererM i o (Abs (Nest (WithExpl CBinder)) e o) -checkLamBinders [] Empty Empty cont = getDistinct >>= \Distinct -> Abs Empty <$> cont [] -checkLamBinders (piExpl:piExpls) (Nest (piB:>piAnn) piBs) lamBs cont = do - prependAbs <$> case piExpl of - Inferred _ _ -> - buildAbsInfWithExpl (getNameHint piB) piExpl piAnn \v -> do - Abs piBs' UnitE <- applyRename (piB@>atomVarName v) $ Abs piBs UnitE - checkLamBinders piExpls piBs' lamBs \vs -> - cont (sink v:vs) - Explicit -> case lamBs of - Nest (UAnnBinder lamB ann cs) lamBsRest -> do - case ann of - UAnn lamAnn -> checkUType lamAnn >>= constrainTypesEq piAnn - UNoAnn -> return () - buildAbsInfWithExpl (getNameHint lamB) Explicit piAnn \v -> do - concatAbs <$> withConstraintBinders cs v do - Abs piBs' UnitE <- applyRename (piB@>sink (atomVarName v)) $ Abs piBs UnitE - extendRenamer (lamB@>sink (atomVarName v)) $ checkLamBinders piExpls piBs' lamBsRest \vs -> - cont (sink v:vs) - Empty -> error "zip error" -checkLamBinders _ _ _ _ = error "zip error" - -checkInstanceParams :: EmitsInf o => [Explicitness] -> Nest CBinder o any -> [UExpr i] -> InfererM i o [CAtom o] -checkInstanceParams expls bsTop paramsTop = do - checkArity expls paramsTop - go bsTop paramsTop + case resultTy' of + Nothing -> bottomUp result + Just resultTy'' -> topDown (sink resultTy'') result + let effTy = EffTy effs' (getType body') + return $ PairE effTy body' + return $ CoreLamExpr (CorePiType appExpl expls bs' effTy) (LamExpr bs' body') + +checkULam :: SrcId -> ULamExpr i -> CorePiType o -> InfererM i o (CoreLamExpr o) +checkULam sid ulam piTy = checkULamPartial (piAsPartialPi piTy) sid ulam + +piAsPartialPi :: CorePiType n -> PartialPiType n +piAsPartialPi (CorePiType appExpl expls bs (EffTy effs ty)) = + PartialPiType appExpl expls bs effs (Check ty) + +typeAsPartialType :: CType n -> PartialType n +typeAsPartialType (TyCon (Pi piTy)) = PartialType $ piAsPartialPi piTy +typeAsPartialType ty = FullType ty + +piAsPartialPiDropResultTy :: CorePiType n -> PartialPiType n +piAsPartialPiDropResultTy (CorePiType appExpl expls bs (EffTy effs _)) = + PartialPiType appExpl expls bs effs Infer + +checkInstanceParams :: Nest CBinder o any -> [UExpr i] -> InfererM i o [CAtom o] +checkInstanceParams bsTop paramsTop = go bsTop paramsTop where - go :: EmitsInf o => Nest CBinder o any -> [UExpr i] -> InfererM i o [CAtom o] + go :: Nest CBinder o any -> [UExpr i] -> InfererM i o [CAtom o] go Empty [] = return [] go (Nest (b:>ty) bs) (x:xs) = do x' <- checkUParam ty x @@ -1958,415 +1455,224 @@ checkInstanceParams expls bsTop paramsTop = do go _ _ = error "zip error" checkInstanceBody - :: EmitsInf o => ClassName o -> [CAtom o] + :: ClassName o -> [CAtom o] -> [UMethodDef i] -> InfererM i o (InstanceBody o) checkInstanceBody className params methods = do - ClassDef _ methodNames _ _ paramBs scBs methodTys <- lookupClassDef className + -- instances are top-level so it's ok to have imprecise root srcIds here + let sid = rootSrcId + ClassDef _ _ methodNames _ _ paramBs scBs methodTys <- lookupClassDef className Abs scBs' methodTys' <- applySubst (paramBs @@> (SubstVal <$> params)) $ Abs scBs $ ListE methodTys superclassTys <- superclassDictTys scBs' - superclassDicts <- mapM (flip trySynthTerm Full) superclassTys + superclassDicts <- mapM (flip (trySynthTerm sid) Full) superclassTys ListE methodTys'' <- applySubst (scBs'@@>(SubstVal<$>superclassDicts)) methodTys' methodsChecked <- mapM (checkMethodDef className methodTys'') methods let (idxs, methods') = unzip $ sortOn fst $ methodsChecked forM_ (repeated idxs) \i -> - throw TypeErr $ "Duplicate method: " ++ pprint (methodNames!!i) - forM_ ([0..(length methodTys'' - 1)] `listDiff` idxs) \i -> - throw TypeErr $ "Missing method: " ++ pprint (methodNames!!i) + throw sid $ DuplicateMethod $ pprint (methodNames!!i) + forM_ ([0..(length methodTys''-1)] `listDiff` idxs) \i -> + throw sid $ MissingMethod $ pprint (methodNames!!i) return $ InstanceBody superclassDicts methods' superclassDictTys :: Nest CBinder o o' -> InfererM i o [CType o] superclassDictTys Empty = return [] superclassDictTys (Nest b bs) = do - Abs bs' UnitE <- liftHoistExcept $ hoist b $ Abs bs UnitE + Abs bs' UnitE <- liftHoistExcept rootSrcId $ hoist b $ Abs bs UnitE (binderType b:) <$> superclassDictTys bs' -checkMethodDef :: EmitsInf o - => ClassName o -> [CorePiType o] -> UMethodDef i -> InfererM i o (Int, CAtom o) -checkMethodDef className methodTys (WithSrcE src m) = addSrcContext src do +checkMethodDef :: ClassName o -> [CorePiType o] -> UMethodDef i -> InfererM i o (Int, CAtom o) +checkMethodDef className methodTys (WithSrcE sid m) = do UMethodDef ~(InternalName _ sourceName v) rhs <- return m MethodBinding className' i <- renameM v >>= lookupEnv when (className /= className') do - ClassBinding (ClassDef classSourceName _ _ _ _ _ _) <- lookupEnv className - throw TypeErr $ pprint sourceName ++ " is not a method of " ++ pprint classSourceName - (i,) <$> Lam <$> checkULam rhs (methodTys !! i) + ClassBinding classDef <- lookupEnv className + throw sid $ NotAMethod (pprint sourceName) (pprint $ getSourceName classDef) + (i,) <$> toAtom <$> Lam <$> checkULam sid rhs (methodTys !! i) -checkUEffRow :: EmitsInf o => UEffectRow i -> InfererM i o (EffectRow CoreIR o) +checkUEffRow :: UEffectRow i -> InfererM i o (EffectRow CoreIR o) checkUEffRow (UEffectRow effs t) = do effs' <- liftM eSetFromList $ mapM checkUEff $ toList effs t' <- case t of Nothing -> return NoTail - Just (~(SIInternalName _ v _ _)) -> do + Just (SourceOrInternalName ~(InternalName sid _ v)) -> do v' <- toAtomVar =<< renameM v - constrainVarTy v' EffKind + expectEq sid EffKind (getType v') return $ EffectRowTail v' return $ EffectRow effs' t' -checkUEff :: EmitsInf o => UEffect i -> InfererM i o (Effect CoreIR o) +checkUEff :: UEffect i -> InfererM i o (Effect CoreIR o) checkUEff eff = case eff of - URWSEffect rws (~(SIInternalName _ region _ _)) -> do + URWSEffect rws (SourceOrInternalName ~(InternalName sid _ region)) -> do region' <- renameM region >>= toAtomVar - constrainVarTy region' (TC HeapType) - return $ RWSEffect rws (Var region') + expectEq sid (TyCon HeapType) (getType region') + return $ RWSEffect rws (toAtom region') UExceptionEffect -> return ExceptionEffect UIOEffect -> return IOEffect -constrainVarTy :: EmitsInf o => CAtomVar o -> CType o -> InfererM i o () -constrainVarTy v tyReq = do - varTy <- return $ getType $ Var v - constrainTypesEq tyReq varTy - type CaseAltIndex = Int -checkCaseAlt :: EmitsBoth o - => CType o -> CType o -> UAlt i -> InfererM i o (IndexedAlt o) +checkCaseAlt :: Emits o => RequiredTy o -> CType o -> UAlt i -> InfererM i o (IndexedAlt o) checkCaseAlt reqTy scrutineeTy (UAlt pat body) = do alt <- checkCasePat pat scrutineeTy do - reqTy' <- sinkM reqTy - withBlockDecls body \result -> - checkOrInferRho noHint result (Check reqTy') + withBlockDecls body \result -> checkOrInfer (sink reqTy) result idx <- getCaseAltIndex pat return $ IndexedAlt idx alt getCaseAltIndex :: UPat i i' -> InfererM i o CaseAltIndex -getCaseAltIndex (WithSrcB _ pat) = case pat of +getCaseAltIndex (WithSrcB sid pat) = case pat of UPatCon ~(InternalName _ _ conName) _ -> do (_, con) <- renameM conName >>= lookupDataCon return con - _ -> throw TypeErr $ "Case patterns must start with a data constructor or variant pattern" - -checkCasePat :: EmitsBoth o - => UPat i i' - -> CType o - -> (forall o'. (EmitsBoth o', Ext o o') => InfererM i' o' (CAtom o')) - -> InfererM i o (Alt CoreIR o) -checkCasePat (WithSrcB pos pat) scrutineeTy cont = addSrcContext pos $ case pat of + _ -> throw sid IllFormedCasePattern + +checkCasePat + :: Emits o + => UPat i i' -> CType o + -> (forall o'. (Emits o', Ext o o') => InfererM i' o' (CAtom o')) + -> InfererM i o (Alt CoreIR o) +checkCasePat (WithSrcB sid pat) scrutineeTy cont = case pat of UPatCon ~(InternalName _ _ conName) ps -> do (dataDefName, con) <- renameM conName >>= lookupDataCon - TyConDef sourceName roleExpls paramBs (ADTCons cons) <- lookupTyCon dataDefName + tyConDef <- lookupTyCon dataDefName + params <- inferParams sid scrutineeTy dataDefName + ADTCons cons <- instantiateTyConDef tyConDef params DataConDef _ _ repTy idxs <- return $ cons !! con - when (length idxs /= nestLength ps) $ throw TypeErr $ - "Unexpected number of pattern binders. Expected " ++ show (length idxs) - ++ " got " ++ show (nestLength ps) - (params, repTy') <- inferParams sourceName roleExpls (Abs paramBs repTy) - constrainTypesEq scrutineeTy $ TypeCon sourceName dataDefName params - buildAltInf repTy' \arg -> do - args <- forM idxs \projs -> do - ans <- normalizeNaryProj (init projs) (Var arg) - emit $ Atom ans - bindLetPats ps args $ cont - _ -> throw TypeErr $ "Case patterns must start with a data constructor or variant pattern" - -inferParams :: (EmitsBoth o, HasNamesE e, SinkableE e, SubstE AtomSubstVal e) - => SourceName -> [RoleExpl] -> Abs (Nest CBinder) e o -> InfererM i o (TyConParams o, e o) -inferParams sourceName roleExpls (Abs paramBs bodyTop) = do - let expls = snd <$> roleExpls - (params, e') <- go expls (Abs paramBs bodyTop) - return (TyConParams expls params, e') - where - go :: (EmitsBoth o, HasNamesE e, SinkableE e, SubstE AtomSubstVal e) - => [Explicitness] -> Abs (Nest CBinder) e o -> InfererM i o ([CAtom o], e o) - go [] (Abs Empty body) = return ([], body) - go (expl:expls) (Abs (Nest (b:>ty) bs) body) = do - x <- case expl of - Explicit -> Var <$> freshInferenceName (TypeInstantiationInfVar sourceName) ty - Inferred argName infMech -> getImplicitArg (sourceName, fromMaybe "_" argName) infMech ty - rest <- applySubst (b@>SubstVal x) $ Abs bs body - (params, body') <- go expls rest - return (x:params, body') - go _ _ = error "zip error" - -bindLetPats :: EmitsBoth o - => Nest UPat i i' -> [CAtomVar o] -> InfererM i' o a -> InfererM i o a -bindLetPats Empty [] cont = cont -bindLetPats (Nest p ps) (x:xs) cont = bindLetPat p x $ bindLetPats ps xs cont + when (length idxs /= nestLength ps) $ throw sid $ PatternArityErr (length idxs) (nestLength ps) + withFreshBinderInf noHint Explicit repTy \b -> Abs b <$> do + buildBlock do + args <- forM idxs \projs -> do + emitToVar =<< applyProjectionsReduced (init projs) (sink $ toAtom $ binderVar b) + bindLetPats ps args $ cont + _ -> throw sid IllFormedCasePattern + +inferParams :: Emits o => SrcId -> CType o -> TyConName o -> InfererM i o (TyConParams o) +inferParams sid ty dataDefName = do + TyConDef sourceName roleExpls paramBs _ <- lookupTyCon dataDefName + let paramExpls = snd <$> roleExpls + let inferenceExpls = paramExpls <&> \case + Explicit -> Inferred Nothing Unify + expl -> expl + paramBsAbs <- buildConstraints (Abs paramBs UnitE) \params _ -> do + let ty' = toType $ UserADTType sourceName (sink dataDefName) $ TyConParams paramExpls params + return [TypeConstraint sid (sink ty) ty'] + args <- inferMixedArgs sid sourceName inferenceExpls paramBsAbs emptyMixedArgs + return $ TyConParams paramExpls args + +bindLetPats + :: (Emits o, HasNamesE e) + => Nest UPat i i' -> [CAtomVar o] + -> (forall o'. (Emits o', DExt o o') => InfererM i' o' (e o')) + -> InfererM i o (e o) +bindLetPats Empty [] cont = getDistinct >>= \Distinct -> cont +bindLetPats (Nest p ps) (x:xs) cont = bindLetPat p x $ bindLetPats ps (sink <$> xs) cont bindLetPats _ _ _ = error "mismatched number of args" -bindLetPat :: EmitsBoth o => UPat i i' -> CAtomVar o -> InfererM i' o a -> InfererM i o a -bindLetPat (WithSrcB pos pat) v cont = addSrcContext pos $ case pat of - UPatBinder b -> extendSubst (b @> atomVarName v) cont +bindLetPat + :: (Emits o, HasNamesE e) + => UPat i i' -> CAtomVar o + -> (forall o'. (Emits o', DExt o o') => InfererM i' o' (e o')) + -> InfererM i o (e o) +bindLetPat (WithSrcB sid pat) v cont = case pat of + UPatBinder b -> getDistinct >>= \Distinct -> extendSubst (b @> atomVarName v) cont UPatProd ps -> do let n = nestLength ps - ty <- return $ getType v - _ <- fromProdType n ty - x <- zonk $ Var v - xs <- forM (iota n) \i -> do - normalizeProj (ProjectProduct i) x >>= emit . Atom + case getType v of + TyCon (ProdType ts) | length ts == n -> return () + ty -> throw sid $ PatTypeErr "product type" (pprint ty) + xs <- forM (iota n) \i -> proj i (toAtom v) >>= emitInline bindLetPats ps xs cont UPatDepPair (PairB p1 p2) -> do - let x = Var v - ty <- return $ getType x - _ <- fromDepPairType ty - x' <- zonk x -- ensure it has a dependent pair type before unpacking - x1 <- getFst x' >>= zonk >>= emit . Atom + case getType v of + TyCon (DepPairTy _) -> return () + ty -> throw sid $ PatTypeErr "dependent pair" (pprint ty) + -- XXX: we're careful here to reduce the projection because of the dependent + -- types. We do the same in the `UPatCon` case. + x1 <- reduceProj 0 (toAtom v) >>= emitInline bindLetPat p1 x1 do - x2 <- getSnd x' >>= zonk >>= emit . Atom + x2 <- getSnd (sink $ toAtom v) >>= emitInline bindLetPat p2 x2 do cont UPatCon ~(InternalName _ _ conName) ps -> do (dataDefName, _) <- lookupDataCon =<< renameM conName - TyConDef sourceName roleExpls paramBs cons <- lookupTyCon dataDefName + TyConDef _ _ _ cons <- lookupTyCon dataDefName case cons of ADTCons [DataConDef _ _ _ idxss] -> do - when (length idxss /= nestLength ps) $ throw TypeErr $ - "Unexpected number of pattern binders. Expected " ++ show (length idxss) - ++ " got " ++ show (nestLength ps) - (params, UnitE) <- inferParams sourceName roleExpls (Abs paramBs UnitE) - constrainVarTy v $ TypeCon sourceName dataDefName params - x <- cheapNormalize =<< zonk (Var v) - xs <- forM idxss \idxs -> normalizeNaryProj idxs x >>= emit . Atom + when (length idxss /= nestLength ps) $ + throw sid $ PatternArityErr (length idxss) (nestLength ps) + void $ inferParams sid (getType $ toAtom v) dataDefName + xs <- forM idxss \idxs -> applyProjectionsReduced idxs (toAtom v) >>= emitInline bindLetPats ps xs cont - _ -> throw TypeErr $ "sum type constructor in can't-fail pattern" + _ -> throw sid SumTypeCantFail UPatTable ps -> do - elemTy <- freshType let n = fromIntegral (nestLength ps) :: Word32 - let iTy = FinConst n - idxTy <- asIxType iTy - ty <- return $ getType $ Var v - constrainTypesEq ty (idxTy ==> elemTy) - v' <- zonk $ Var v + case getType v of + TyCon (TabPi (TabPiType _ (_:>FinConst n') _)) | n == n' -> return () + ty -> throw sid $ PatTypeErr ("Fin " ++ show n ++ " table") (pprint ty) xs <- forM [0 .. n - 1] \i -> do - emit =<< mkTabApp v' [NewtypeCon (FinCon (NatVal n)) (NatVal $ fromIntegral i)] + emitToVar =<< mkTabApp (toAtom v) (toAtom $ NewtypeCon (FinCon (NatVal n)) (NatVal $ fromIntegral i)) bindLetPats ps xs cont + where + emitInline :: Emits n => CAtom n -> InfererM i n (AtomVar CoreIR n) + emitInline atom = emitDecl noHint InlineLet $ Atom atom -checkAnn :: EmitsInf o => SourceName -> UAnn req i -> InfererM i o (CType o) -checkAnn binderSourceName ann = case ann of - UAnn ty -> checkUType ty - UNoAnn -> do - let desc = AnnotationInfVar binderSourceName - TyVar <$> freshInferenceName desc TyKind - -checkUType :: EmitsInf o => UType i -> InfererM i o (CType o) +checkUType :: UType i -> InfererM i o (CType o) checkUType t = do - Type t' <- checkUParam TyKind t + Just t' <- toMaybeType <$> checkUParam TyKind t return t' -checkUParam :: EmitsInf o => Kind CoreIR o -> UType i -> InfererM i o (CAtom o) -checkUParam k uty@(WithSrcE pos _) = addSrcContext pos $ - withReducibleEmissions msg $ withoutEffects $ checkRho noHint uty (sink k) - where msg = "Can't reduce type expression: " ++ pprint uty +checkUParam :: Kind CoreIR o -> UType i -> InfererM i o (CAtom o) +checkUParam k uty = + withReducibleEmissions (getSrcId uty) msg $ withAllowedEffects Pure $ topDownExplicit (sink k) uty + where msg = CantReduceType $ pprint uty -inferTabCon :: forall i o. EmitsBoth o - => NameHint -> [UExpr i] -> RequiredTy CType o -> InfererM i o (CAtom o) -inferTabCon hint xs reqTy = do +inferTabCon :: forall i o. Emits o => SrcId -> [UExpr i] -> InfererM i o (CAtom o) +inferTabCon sid xs = do + let n = fromIntegral (length xs) :: Word32 + let finTy = FinConst n + elemTy <- case xs of + [] -> throw sid InferEmptyTable + x:_ -> getType <$> bottomUp x + ixTy <- asIxType sid finTy + let tabTy = ixTy ==> elemTy + xs' <- forM xs \x -> topDown elemTy x + let dTy = toType $ DataDictType elemTy + Just dataDict <- toMaybeDict <$> trySynthTerm sid dTy Full + emit $ TabCon (Just $ WhenIRE dataDict) tabTy xs' + +checkTabCon :: forall i o. Emits o => TabPiType CoreIR o -> SrcId -> [UExpr i] -> InfererM i o (CAtom o) +checkTabCon tabTy@(TabPiType _ b elemTy) sid xs = do let n = fromIntegral (length xs) :: Word32 let finTy = FinConst n - ctx <- srcPosCtx <$> getErrCtx - let dataDictHole dTy = Just $ WhenIRE $ DictHole (AlwaysEqual ctx) dTy Full - case reqTy of - Infer -> do - elemTy <- case xs of - [] -> freshType - (x:_) -> getType <$> inferRho noHint x - ixTy <- asIxType finTy - let tabTy = ixTy ==> elemTy - xs' <- forM xs \x -> checkRho noHint x elemTy - dTy <- DictTy <$> dataDictType elemTy - liftM Var $ emitHinted hint $ TabCon (dataDictHole dTy) tabTy xs' - Check tabTy -> do - TabPiType _ b elemTy <- fromTabPiType True tabTy - constrainTypesEq (binderType b) finTy - xs' <- forM (enumerate xs) \(i, x) -> do - let i' = NewtypeCon (FinCon (NatVal n)) (NatVal $ fromIntegral i) :: CAtom o - elemTy' <- applySubst (b@>SubstVal i') elemTy - checkRho noHint x elemTy' - dTy <- case hoist b elemTy of - HoistSuccess elemTy' -> DictTy <$> dataDictType elemTy' - HoistFailure _ -> ignoreExcept <$> liftEnvReaderT do - withFreshBinder noHint finTy \b' -> do - elemTy' <- applyRename (b@>binderName b') elemTy - dTy <- DictTy <$> dataDictType elemTy' - return $ Pi $ CorePiType ImplicitApp [Inferred Nothing Unify] (UnaryNest b') (EffTy Pure dTy) - liftM Var $ emitHinted hint $ TabCon (dataDictHole dTy) tabTy xs' - --- Bool flag is just to tweak the reported error message -fromTabPiType :: EmitsBoth o => Bool -> CType o -> InfererM i o (TabPiType CoreIR o) -fromTabPiType _ (TabPi piTy) = return piTy -fromTabPiType expectPi ty = do - a <- freshType - b <- freshType - a' <- asIxType a - let piTy = nonDepTabPiType a' b - if expectPi then constrainTypesEq (TabPi piTy) ty - else constrainTypesEq ty (TabPi piTy) - return piTy - -fromProdType :: EmitsBoth o => Int -> CType o -> InfererM i o [CType o] -fromProdType n (ProdTy ts) | length ts == n = return ts -fromProdType n ty = do - ts <- mapM (const $ freshType) (replicate n ()) - constrainTypesEq (ProdTy ts) ty - return ts - -fromDepPairType :: EmitsBoth o => CType o -> InfererM i o (DepPairType CoreIR o) -fromDepPairType (DepPairTy t) = return t -fromDepPairType ty = throw TypeErr $ "Expected a dependent pair, but got: " ++ pprint ty - -addEffects :: EmitsBoth o => EffectRow CoreIR o -> InfererM i o () -addEffects eff = do - allowed <- checkAllowedUnconditionally eff - unless allowed $ do - effsAllowed <- getAllowedEffects - eff' <- openEffectRow eff - constrainEq (Eff effsAllowed) (Eff eff') - -checkAllowedUnconditionally :: EffectRow CoreIR o -> InfererM i o Bool -checkAllowedUnconditionally Pure = return True -checkAllowedUnconditionally eff = do - eff' <- zonk eff - effAllowed <- getAllowedEffects >>= zonk - return $ case checkExtends effAllowed eff' of - Failure _ -> False - Success () -> True - -openEffectRow :: EmitsBoth o => EffectRow CoreIR o -> InfererM i o (EffectRow CoreIR o) -openEffectRow (EffectRow effs NoTail) = extendEffRow effs <$> freshEff -openEffectRow effRow = return effRow - -asIxType :: CType o -> InfererM i o (IxType CoreIR o) -asIxType ty = do - dictTy <- DictTy <$> ixDictType ty - ctx <- srcPosCtx <$> getErrCtx - return $ IxType ty $ IxDictAtom $ DictHole (AlwaysEqual ctx) dictTy Full -{-# SCC asIxType #-} + expectEq sid (binderType b) finTy + xs' <- forM (enumerate xs) \(i, x) -> do + let i' = toAtom (NewtypeCon (FinCon (NatVal n)) (NatVal $ fromIntegral i)) :: CAtom o + elemTy' <- applySubst (b@>SubstVal i') elemTy + topDown elemTy' x + dTy <- case hoist b elemTy of + HoistSuccess elemTy' -> return $ toType $ DataDictType elemTy' + HoistFailure _ -> ignoreExcept <$> liftEnvReaderT do + withFreshBinder noHint finTy \b' -> do + elemTy' <- applyRename (b@>binderName b') elemTy + let dTy = toType $ DataDictType elemTy' + return $ toType $ CorePiType ImplicitApp [Inferred Nothing Unify] (UnaryNest b') (EffTy Pure dTy) + Just dataDict <- toMaybeDict <$> trySynthTerm sid dTy Full + emit $ TabCon (Just $ WhenIRE dataDict) (TyCon (TabPi tabTy)) xs' + +addEffects :: SrcId -> EffectRow CoreIR o -> InfererM i o () +addEffects _ Pure = return () +addEffects sid eff = do + effsAllowed <- infEffects <$> getInfState + case checkExtends effsAllowed eff of + Success () -> return () + Failure _ -> expectEq sid (Eff effsAllowed) (Eff eff) + +getIxDict :: SrcId -> CType o -> InfererM i o (IxDict CoreIR o) +getIxDict sid t = fromJust <$> toMaybeDict <$> trySynthTerm sid (toType $ IxDictType t) Full + +asIxType :: SrcId -> CType o -> InfererM i o (IxType CoreIR o) +asIxType sid ty = IxType ty <$> getIxDict sid ty -- === Solver === -newtype SolverSubst n = SolverSubst (M.Map (CAtomName n) (CAtom n)) - -instance Pretty (SolverSubst n) where - pretty (SolverSubst m) = pretty $ M.toList m - -class (CtxReader1 m, EnvReader m) => Solver (m::MonadKind1) where - zonk :: (SubstE AtomSubstVal e, SinkableE e) => e n -> m n (e n) - extendSolverSubst :: CAtomName n -> CAtom n -> m n () - emitSolver :: EmitsInf n => SolverBinding n -> m n (CAtomVar n) - solveLocal :: (SinkableE e, HoistableE e, RenameE e) - => (forall l. (EmitsInf l, Ext n l, Distinct l) => m l (e l)) - -> m n (e n) - -type SolverOutMap = InfOutMap - -data SolverOutFrag (n::S) (l::S) = - SolverOutFrag (SolverEmissions n l) (Constraints l) -newtype Constraints n = Constraints (SnocList (CAtomName n, CAtom n)) - deriving (Monoid, Semigroup) -type SolverEmissions = RNest (BinderP (AtomNameC CoreIR) SolverBinding) - -instance GenericE Constraints where - type RepE Constraints = ListE (CAtomName `PairE` CAtom) - fromE (Constraints xs) = ListE [PairE x y | (x,y) <- toList xs] - {-# INLINE fromE #-} - toE (ListE xs) = Constraints $ toSnocList $ [(x,y) | PairE x y <- xs] - {-# INLINE toE #-} - -instance SinkableE Constraints -instance RenameE Constraints -instance HoistableE Constraints -instance Pretty (Constraints n) where - pretty (Constraints xs) = pretty $ unsnoc xs - -instance GenericB SolverOutFrag where - type RepB SolverOutFrag = PairB SolverEmissions (LiftB Constraints) - fromB (SolverOutFrag em subst) = PairB em (LiftB subst) - toB (PairB em (LiftB subst)) = SolverOutFrag em subst - -instance ProvesExt SolverOutFrag -instance RenameB SolverOutFrag -instance BindsNames SolverOutFrag -instance SinkableB SolverOutFrag - -instance OutFrag SolverOutFrag where - emptyOutFrag = SolverOutFrag REmpty mempty - catOutFrags (SolverOutFrag em ss) (SolverOutFrag em' ss') = - withExtEvidence em' $ - SolverOutFrag (em >>> em') (sink ss <> ss') - -instance ExtOutMap InfOutMap SolverOutFrag where - extendOutMap infOutMap outFrag = - extendOutMap infOutMap $ liftSolverOutFrag outFrag - -newtype SolverM (n::S) (a:: *) = - SolverM { runSolverM' :: InplaceT SolverOutMap SolverOutFrag SearcherM n a } - deriving (Functor, Applicative, Monad, MonadFail, Alternative, Searcher, - ScopeReader, Fallible, CtxReader) - -liftSolverM :: EnvReader m => SolverM n a -> m n (Except a) -liftSolverM cont = do - env <- unsafeGetEnv - Distinct <- getDistinct - return do - maybeResult <- runSearcherM $ runInplaceT (initInfOutMap env) $ - runSolverM' $ cont - case maybeResult of - Nothing -> throw TypeErr "No solution" - Just (_, result) -> return result -{-# INLINE liftSolverM #-} - -instance EnvReader SolverM where - unsafeGetEnv = SolverM do - InfOutMap env _ _ _ _ <- getOutMapInplaceT - return env - {-# INLINE unsafeGetEnv #-} - -newtype SolverEmission (n::S) (l::S) = SolverEmission (BinderP (AtomNameC CoreIR) SolverBinding n l) -instance ExtOutMap SolverOutMap SolverEmission where - extendOutMap env (SolverEmission e) = env `extendOutMap` toEnvFrag e -instance ExtOutFrag SolverOutFrag SolverEmission where - extendOutFrag (SolverOutFrag es substs) (SolverEmission e) = - withSubscopeDistinct e $ SolverOutFrag (RNest es e) (sink substs) - -instance Solver SolverM where - extendSolverSubst v ty = SolverM $ - void $ extendTrivialInplaceT $ - SolverOutFrag REmpty (singleConstraint v ty) - {-# INLINE extendSolverSubst #-} - - zonk e = SolverM do - Distinct <- getDistinct - solverOutMap <- getOutMapInplaceT - return $ zonkWithOutMap solverOutMap $ sink e - {-# INLINE zonk #-} - - emitSolver binding = do - v <- SolverM $ freshExtendSubInplaceT (getNameHint @String "?") \b -> - (SolverEmission (b:>binding), binderName b) - toAtomVar v - {-# INLINE emitSolver #-} - - solveLocal cont = SolverM do - results <- locallyMutableInplaceT (do - Distinct <- getDistinct - EmitsInf <- fabricateEmitsInfEvidenceM - runSolverM' cont) (\d e -> return $ Abs d e) - Abs (SolverOutFrag unsolvedInfNames _) result <- return results - case unsolvedInfNames of - REmpty -> return result - _ -> case hoist unsolvedInfNames result of - HoistSuccess result' -> return result' - HoistFailure vs -> - throw TypeErr $ "Ambiguous type variables: " ++ pprint vs - {-# INLINE solveLocal #-} - -instance Unifier SolverM - -freshInferenceName :: (EmitsInf n, Solver m) => InfVarDesc -> Kind CoreIR n -> m n (CAtomVar n) -freshInferenceName desc k = do - ctx <- srcPosCtx <$> getErrCtx - emitSolver $ InfVarBound k (ctx, desc) -{-# INLINE freshInferenceName #-} - -freshSkolemName :: (EmitsInf n, Solver m) => Kind CoreIR n -> m n (CAtomVar n) -freshSkolemName k = emitSolver $ SkolemBound k -{-# INLINE freshSkolemName #-} - -type Solver2 (m::MonadKind2) = forall i. Solver (m i) - -emptySolverSubst :: SolverSubst n -emptySolverSubst = SolverSubst mempty - -singleConstraint :: CAtomName n -> CAtom n -> Constraints n -singleConstraint v ty = Constraints $ toSnocList [(v, ty)] - -- TODO: put this pattern and friends in the Name library? Don't really want to -- have to think about `eqNameColorRep` just to implement a partial map. lookupSolverSubst :: forall c n. Color c => SolverSubst n -> Name c n -> AtomSubstVal c n @@ -2375,209 +1681,302 @@ lookupSolverSubst (SolverSubst m) name = Nothing -> Rename name Just (ColorsEqual :: ColorsEqual c (AtomNameC CoreIR))-> case M.lookup name m of Nothing -> Rename name - Just ty -> SubstVal ty - -applySolverSubstE :: (SubstE AtomSubstVal e, Distinct n) - => Env n -> SolverSubst n -> e n -> e n -applySolverSubstE env solverSubst@(SolverSubst m) e = - if M.null m then e else fmapNames env (lookupSolverSubst solverSubst) e - -zonkWithOutMap :: (SubstE AtomSubstVal e, Distinct n) - => InfOutMap n -> e n -> e n -zonkWithOutMap (InfOutMap bindings solverSubst _ _ _) e = - applySolverSubstE bindings solverSubst e - -liftSolverOutFrag :: Distinct l => SolverOutFrag n l -> InfOutFrag n l -liftSolverOutFrag (SolverOutFrag emissions subst) = - InfOutFrag (liftSolverEmissions emissions) mempty subst - -liftSolverEmissions :: Distinct l => SolverEmissions n l -> InfEmissions n l -liftSolverEmissions emissions = - fmapRNest (\(b:>emission) -> (b:>RightE emission)) emissions - -fmapRNest :: (forall ii ii'. b ii ii' -> b' ii ii') - -> RNest b i i' - -> RNest b' i i' -fmapRNest _ REmpty = REmpty -fmapRNest f (RNest rest b) = RNest (fmapRNest f rest) (f b) - -instance GenericE SolverSubst where - -- XXX: this is a bit sketchy because it's not actually bijective... - type RepE SolverSubst = ListE (PairE CAtomName CAtom) - fromE (SolverSubst m) = ListE $ map (uncurry PairE) $ M.toList m - {-# INLINE fromE #-} - toE (ListE pairs) = SolverSubst $ M.fromList $ map fromPairE pairs - {-# INLINE toE #-} - -instance SinkableE SolverSubst where -instance RenameE SolverSubst where -instance HoistableE SolverSubst - -constrainTypesEq :: EmitsInf o => CType o -> CType o -> InfererM i o () -constrainTypesEq t1 t2 = constrainEq (Type t1) (Type t2) -- TODO: use a type class instead? - -constrainEq :: EmitsInf o => CAtom o -> CAtom o -> InfererM i o () -constrainEq t1 t2 = do - t1' <- zonk t1 - t2' <- zonk t2 - msg <- liftEnvReaderM $ do + Just sol -> SubstVal sol + +applyConstraint :: Constraint o -> SolverM i o () +applyConstraint = \case + TypeConstraint sid t1 t2 -> constrainEq sid t1 t2 + EffectConstraint sid r1 r2' -> do + -- r1 shouldn't have inference variables. And we can't infer anything about + -- any inference variables in r2's explicit effects because we don't know + -- how they line up with r1's. So this is just about figuring out r2's tail. + r2 <- zonk r2' + let msg = DisallowedEffects (pprint r1) (pprint r2) + case checkExtends r1 r2 of + Success () -> return () + Failure _ -> searchFailureAsTypeErr sid msg do + EffectRow effs1 t1 <- return r1 + EffectRow effs2 (EffectRowTail v2) <- return r2 + guard =<< isUnificationName (atomVarName v2) + guard $ null (eSetToList $ effs2 `eSetDifference` effs1) + let extras1 = effs1 `eSetDifference` effs2 + extendSolution v2 (toAtom $ EffectRow extras1 t1) + +constrainEq :: ToAtom e CoreIR => SrcId -> e o -> e o -> SolverM i o () +constrainEq sid t1 t2 = do + t1' <- zonk $ toAtom t1 + t2' <- zonk $ toAtom t2 + msg <- liftEnvReaderM do ab <- renameForPrinting $ PairE t1' t2' return $ canonicalizeForPrinting ab \(Abs infVars (PairE t1Pretty t2Pretty)) -> - "Expected: " ++ pprint t1Pretty - ++ "\n Actual: " ++ pprint t2Pretty - ++ (case infVars of - Empty -> "" - _ -> "\n(Solving for: " ++ pprint (nestToList pprint infVars) ++ ")") - void $ addContext msg $ liftSolverMInf $ unify t1' t2' - -class (Alternative1 m, Searcher1 m, Fallible1 m, Solver m) => Unifier m - -class (AlphaEqE e, SinkableE e, SubstE AtomSubstVal e) => Unifiable (e::E) where - unifyZonked :: EmitsInf n => e n -> e n -> SolverM n () - -tryConstrainEq :: EmitsInf o => CAtom o -> CAtom o -> InfererM i o () -tryConstrainEq t1 t2 = do - constrainEq t1 t2 `catchErr` \errs -> case errs of - Errs [Err TypeErr _ _] -> return () - _ -> throwErrs errs - -unify :: (EmitsInf n, Unifiable e) => e n -> e n -> SolverM n () -unify e1 e2 = do - e1' <- zonk e1 - e2' <- zonk e2 - (unifyZonked e1' e2' throw TypeErr "") -{-# INLINE unify #-} -{-# SCC unify #-} + UnificationFailure (pprint t1Pretty) (pprint t2Pretty) (nestToList pprint infVars) + void $ searchFailureAsTypeErr sid msg $ unify t1' t2' + +searchFailureAsTypeErr :: ToErr e => SrcId -> e -> SolverM i n a -> SolverM i n a +searchFailureAsTypeErr sid msg cont = cont <|> throw sid msg +{-# INLINE searchFailureAsTypeErr #-} + +class AlphaEqE e => Unifiable (e::E) where + unify :: e n -> e n -> SolverM i n () + +instance Unifiable (Stuck CoreIR) where + unify s1 s2 = do + x1 <- zonkStuck s1 + x2 <- zonkStuck s2 + case (x1, x2) of + (Con c, Con c') -> unify c c' + (Stuck _ s, Stuck _ s') -> unifyStuckZonked s s' + (Stuck _ s, Con c) -> unifyStuckConZonked s c + (Con c, Stuck _ s) -> unifyStuckConZonked s c + +-- assumes both `CStuck` args are zonked +unifyStuckZonked :: CStuck n -> CStuck n -> SolverM i n () +unifyStuckZonked s1 s2 = do + x1 <- mkStuck s1 + x2 <- mkStuck s2 + case (s1, s2) of + (Var v1, Var v2) -> do + if atomVarName v1 == atomVarName v2 + then return () + else extendSolution v2 x1 <|> extendSolution v1 x2 + (_, Var v2) -> extendSolution v2 x1 + (Var v1, _) -> extendSolution v1 x2 + (_, _) -> unifyEq s1 s2 + +unifyStuckConZonked :: CStuck n -> Con CoreIR n -> SolverM i n () +unifyStuckConZonked s x = case s of + Var v -> extendSolution v (Con x) + _ -> empty + +unifyStuckCon :: CStuck n -> Con CoreIR n -> SolverM i n () +unifyStuckCon s con = do + x <- zonkStuck s + case x of + Stuck _ s' -> unifyStuckConZonked s' con + Con con' -> unify con' con instance Unifiable (Atom CoreIR) where - unifyZonked e1 e2 = confuseGHC >>= \_ -> case sameConstructor e1 e2 of - False -> case (e1, e2) of - (t, Var (AtomVar v _)) -> extendSolution v t - (Var (AtomVar v _), t) -> extendSolution v t - _ -> empty - True -> case (e1, e2) of - (Var (AtomVar v' _), Var (AtomVar v _)) -> - if v == v' then return () else extendSolution v e1 <|> extendSolution v' e2 - (Eff eff, Eff eff') -> unify eff eff' - (Type t, Type t') -> case (t, t') of - (Pi piTy, Pi piTy') -> unify piTy piTy' - (TabPi piTy, TabPi piTy') -> unifyTabPiType piTy piTy' - (TC con, TC con') -> do - GenericOpRep name ts xs [] <- return $ fromEGenericOpRep con - GenericOpRep name' ts' xs' [] <- return $ fromEGenericOpRep con' - guard $ name == name' && length ts == length ts' && length xs == length xs' - zipWithM_ unify (Type <$> ts) (Type <$> ts') - zipWithM_ unify xs xs' - (DictTy d, DictTy d') -> unify d d' - (NewtypeTyCon con, NewtypeTyCon con') -> unify con con' - _ -> unifyEq t t' - _ -> unifyEq e1 e2 + unify (Con c) (Con c') = unify c c' + unify (Stuck _ s) (Stuck _ s') = unify s s' + unify (Stuck _ s) (Con c) = unifyStuckCon s c + unify (Con c) (Stuck _ s) = unifyStuckCon s c + +-- TODO: do this directly rather than going via `CAtom` using `Type`. We just +-- need to deal with `Stuck`. +instance Unifiable (Type CoreIR) where + unify (TyCon c) (TyCon c') = unify c c' + unify (StuckTy _ s) (StuckTy _ s') = unify s s' + unify (StuckTy _ s) (TyCon c) = unifyStuckCon s (TyConAtom c) + unify (TyCon c) (StuckTy _ s) = unifyStuckCon s (TyConAtom c) + +instance Unifiable (Con CoreIR) where + unify e1 e2 = case e1 of + ( Lit x ) -> do + { Lit x' <- matchit; guard (x == x')} + ( ProdCon xs ) -> do + { ProdCon xs' <- matchit; unifyLists xs xs'} + ( SumCon ts i x ) -> do + { SumCon ts' i' x' <- matchit; unifyLists ts ts'; guard (i==i'); unify x x'} + ( DepPair t x y ) -> do + { DepPair t' x' y' <- matchit; unify t t'; unify x x'; unify y y'} + ( HeapVal ) -> do + { HeapVal <- matchit; return ()} + ( Eff eff ) -> do + { Eff eff' <- matchit; unify eff eff'} + ( Lam lam ) -> do + { Lam lam' <- matchit; unifyEq lam lam'} + ( NewtypeCon con x ) -> do + { NewtypeCon con' x' <- matchit; unifyEq con con'; unify x x'} + ( TyConAtom t ) -> do + { TyConAtom t' <- matchit; unify t t'} + ( DictConAtom d ) -> do + { DictConAtom d' <- matchit; unifyEq d d'} + where matchit = return e2 + +instance Unifiable (TyCon CoreIR) where + unify t1 t2 = case t1 of + ( BaseType b ) -> do + { BaseType b' <- matchit; guard $ b == b'} + ( HeapType ) -> do + { HeapType <- matchit; return () } + ( TypeKind ) -> do + { TypeKind <- matchit; return () } + ( Pi piTy ) -> do + { Pi piTy' <- matchit; unify piTy piTy'} + ( TabPi piTy) -> do + { TabPi piTy' <- matchit; unify piTy piTy'} + ( DictTy d ) -> do + { DictTy d' <- matchit; unify d d'} + ( NewtypeTyCon con ) -> do + { NewtypeTyCon con' <- matchit; unify con con'} + ( SumType ts ) -> do + { SumType ts' <- matchit; unifyLists ts ts'} + ( ProdType ts ) -> do + { ProdType ts' <- matchit; unifyLists ts ts'} + ( RefType h t ) -> do + { RefType h' t' <- matchit; unify h h'; unify t t'} + ( DepPairTy t ) -> do + { DepPairTy t' <- matchit; unify t t'} + where matchit = return t2 + +unifyLists :: Unifiable e => [e n] -> [e n] -> SolverM i n () +unifyLists [] [] = return () +unifyLists (x:xs) (y:ys) = unify x y >> unifyLists xs ys +unifyLists _ _ = empty instance Unifiable DictType where - unifyZonked (DictType _ c params) (DictType _ c' params') = - guard (c == c') >> zipWithM_ unify params params' - {-# INLINE unifyZonked #-} + unify d1 d2 = case d1 of + ( DictType _ c params )-> do + { DictType _ c' params' <- matchit; guard (c == c'); unifyLists params params'} + ( IxDictType t ) -> do + { IxDictType t' <- matchit; unify t t'} + ( DataDictType t ) -> do + { DataDictType t' <- matchit; unify t t'} + where matchit = return d2 + {-# INLINE unify #-} instance Unifiable NewtypeTyCon where - unifyZonked e1 e2 = case (e1, e2) of - (Nat, Nat) -> return () - (Fin n, Fin n') -> unify n n' - (EffectRowKind, EffectRowKind) -> return () - (UserADTType _ c params, UserADTType _ c' params') -> guard (c == c') >> unify params params' - _ -> empty + unify e1 e2 = case e1 of + ( Nat ) -> do + { Nat <- matchit; return ()} + ( Fin n ) -> do + { Fin n' <- matchit; unify n n'} + ( EffectRowKind ) -> do + { EffectRowKind <- matchit; return ()} + ( UserADTType _ c params ) -> do + { UserADTType _ c' params' <- matchit; guard (c == c') >> unify params params' } + where matchit = return e2 instance Unifiable (IxType CoreIR) where -- We ignore the dictionaries because we assume coherence - unifyZonked (IxType t _) (IxType t' _) = unifyZonked t t' - --- TODO: do this directly rather than going via `CAtom` using `Type`. We just --- need to deal with `TyVar`. -instance Unifiable (Type CoreIR) where - unifyZonked t t' = unifyZonked (Type t) (Type t') + unify (IxType t _) (IxType t' _) = unify t t' instance Unifiable TyConParams where -- We ignore the dictionaries because we assume coherence - unifyZonked ps ps' = zipWithM_ unify (ignoreSynthParams ps) (ignoreSynthParams ps') + unify ps ps' = zipWithM_ unify (ignoreSynthParams ps) (ignoreSynthParams ps') instance Unifiable (EffectRow CoreIR) where - unifyZonked x1 x2 = + unify x1 x2 = unifyDirect x1 x2 <|> unifyDirect x2 x1 <|> unifyZip x1 x2 where - unifyDirect :: EmitsInf n => EffectRow CoreIR n -> EffectRow CoreIR n -> SolverM n () + unifyDirect :: EffectRow CoreIR n -> EffectRow CoreIR n -> SolverM i n () unifyDirect r@(EffectRow effs' mv') (EffectRow effs (EffectRowTail v)) | null (eSetToList effs) = case mv' of EffectRowTail v' | v == v' -> guard $ null $ eSetToList effs' - _ -> extendSolution (atomVarName v) (Eff r) + _ -> extendSolution v (Con $ Eff r) unifyDirect _ _ = empty {-# INLINE unifyDirect #-} - unifyZip :: EmitsInf n => EffectRow CoreIR n -> EffectRow CoreIR n -> SolverM n () + unifyZip :: EffectRow CoreIR n -> EffectRow CoreIR n -> SolverM i n () unifyZip r1 r2 = case (r1, r2) of (EffectRow effs1 t1, EffectRow effs2 t2) | not (eSetNull effs1 || eSetNull effs2) -> do let extras1 = effs1 `eSetDifference` effs2 let extras2 = effs2 `eSetDifference` effs1 - newRow <- freshEff - unify (EffectRow mempty t1) (extendEffRow extras2 newRow) - unify (extendEffRow extras1 newRow) (EffectRow mempty t2) + void $ withFreshEff \newRow -> do + unify (EffectRow mempty (sink t1)) (extendEffRow (sink extras2) newRow) + unify (extendEffRow (sink extras1) newRow) (EffectRow mempty (sink t2)) + return UnitE _ -> unifyEq r1 r2 -unifyEq :: AlphaEqE e => e n -> e n -> SolverM n () +withFreshEff + :: Zonkable e + => (forall o'. DExt o o' => EffectRow CoreIR o' -> SolverM i o' (e o')) + -> SolverM i o (e o) +withFreshEff cont = + withFreshUnificationVarNoEmits rootSrcId MiscInfVar EffKind \v -> do + cont $ EffectRow mempty $ EffectRowTail v +{-# INLINE withFreshEff #-} + +unifyEq :: AlphaEqE e => e n -> e n -> SolverM i n () unifyEq e1 e2 = guard =<< alphaEq e1 e2 {-# INLINE unifyEq #-} instance Unifiable CorePiType where - unifyZonked (CorePiType appExpl1 expls1 bsTop1 effTy1) + unify (CorePiType appExpl1 expls1 bsTop1 effTy1) (CorePiType appExpl2 expls2 bsTop2 effTy2) = do unless (appExpl1 == appExpl2) empty unless (expls1 == expls2) empty go (Abs bsTop1 effTy1) (Abs bsTop2 effTy2) where - go :: EmitsInf n - => Abs (Nest CBinder) (EffTy CoreIR) n + go :: Abs (Nest CBinder) (EffTy CoreIR) n -> Abs (Nest CBinder) (EffTy CoreIR) n - -> SolverM n () + -> SolverM i n () go (Abs Empty (EffTy e1 t1)) (Abs Empty (EffTy e2 t2)) = unify t1 t2 >> unify e1 e2 go (Abs (Nest (b1:>t1) bs1) rest1) (Abs (Nest (b2:>t2) bs2) rest2) = do unify t1 t2 - v <- freshSkolemName t1 - ab1 <- zonk =<< applySubst (b1@>SubstVal (Var v)) (Abs bs1 rest1) - ab2 <- zonk =<< applySubst (b2@>SubstVal (Var v)) (Abs bs2 rest2) - go ab1 ab2 + void $ withFreshSkolemName t1 \v -> do + ab1 <- zonk =<< applyRename (b1@>atomVarName v) (Abs bs1 rest1) + ab2 <- zonk =<< applyRename (b2@>atomVarName v) (Abs bs2 rest2) + go ab1 ab2 + return UnitE go _ _ = empty -unifyTabPiType :: EmitsInf n => TabPiType CoreIR n -> TabPiType CoreIR n -> SolverM n () -unifyTabPiType (TabPiType _ b1 ty1) (TabPiType _ b2 ty2) = do - let ann1 = binderType b1 - let ann2 = binderType b2 - unify ann1 ann2 - v <- freshSkolemName ann1 - ty1' <- applySubst (b1@>SubstVal (Var v)) ty1 - ty2' <- applySubst (b2@>SubstVal (Var v)) ty2 - unify ty1' ty2' - -extendSolution :: CAtomName n -> CAtom n -> SolverM n () -extendSolution v t = - isInferenceName v >>= \case +instance Unifiable (TabPiType CoreIR) where + unify (TabPiType _ b1 ty1) (TabPiType _ b2 ty2) = + unify (Abs b1 ty1) (Abs b2 ty2) + +instance Unifiable (DepPairType CoreIR) where + unify (DepPairType expl1 b1 rhs1) (DepPairType expl2 b2 rhs2) = do + guard $ expl1 == expl2 + unify (Abs b1 rhs1) (Abs b2 rhs2) + +instance Unifiable (Abs CBinder CType) where + unify (Abs b1 ty1) (Abs b2 ty2) = do + let ann1 = binderType b1 + let ann2 = binderType b2 + unify ann1 ann2 + void $ withFreshSkolemName ann1 \v -> do + ty1' <- applyRename (b1@>atomVarName v) ty1 + ty2' <- applyRename (b2@>atomVarName v) ty2 + unify ty1' ty2' + return UnitE + +withFreshSkolemName + :: Zonkable e => Kind CoreIR o + -> (forall o'. DExt o o' => CAtomVar o' -> SolverM i o' (e o')) + -> SolverM i o (e o) +withFreshSkolemName ty cont = diffStateT1 \s -> do + withFreshBinder "skol" (SkolemBound ty) \b -> do + (ans, diff) <- runDiffStateT1 (sink s) do + v <- toAtomVar $ binderName b + ans <- cont v >>= zonk + case hoist b ans of + HoistSuccess ans' -> return ans' + HoistFailure _ -> empty + case hoist b diff of + HoistSuccess diff' -> return (ans, diff') + HoistFailure _ -> empty +{-# INLINE withFreshSkolemName #-} + +extendSolution :: CAtomVar n -> CAtom n -> SolverM i n () +extendSolution (AtomVar v _) t = + isUnificationName v >>= \case True -> do - when (v `isFreeIn` t) $ throw TypeErr $ "Occurs check failure: " ++ pprint (v, t) + when (v `isFreeIn` t) solverFail -- occurs check -- When we unify under a pi binder we replace its occurrences with a -- skolem variable. We don't want to unify with terms containing these -- variables because that would mean inferring dependence, which is a can -- of worms. forM_ (freeAtomVarsList t) \fv -> - whenM (isSkolemName fv) $ throw TypeErr $ "Can't unify with skolem vars" - extendSolverSubst v t + whenM (isSkolemName fv) solverFail -- can't unify with skolems + addConstraint v t False -> empty -isInferenceName :: EnvReader m => CAtomName n -> m n Bool -isInferenceName v = lookupEnv v >>= \case - AtomNameBinding (SolverBound (InfVarBound _ _)) -> return True +isUnificationName :: EnvReader m => CAtomName n -> m n Bool +isUnificationName v = lookupEnv v >>= \case + AtomNameBinding (SolverBound (InfVarBound _)) -> return True _ -> return False -{-# INLINE isInferenceName #-} +{-# INLINE isUnificationName #-} + +isSolverName :: EnvReader m => CAtomName n -> m n Bool +isSolverName v = lookupEnv v >>= \case + AtomNameBinding (SolverBound _) -> return True + _ -> return False + isSkolemName :: EnvReader m => CAtomName n -> m n Bool isSkolemName v = lookupEnv v >>= \case @@ -2585,22 +1984,10 @@ isSkolemName v = lookupEnv v >>= \case _ -> return False {-# INLINE isSkolemName #-} -freshType :: (EmitsInf n, Solver m) => m n (CType n) -freshType = TyVar <$> freshInferenceName MiscInfVar TyKind -{-# INLINE freshType #-} - -freshAtom :: (EmitsInf n, Solver m) => Type CoreIR n -> m n (CAtom n) -freshAtom t = Var <$> freshInferenceName MiscInfVar t -{-# INLINE freshAtom #-} - -freshEff :: (EmitsInf n, Solver m) => m n (EffectRow CoreIR n) -freshEff = EffectRow mempty . EffectRowTail <$> freshInferenceName MiscInfVar EffKind -{-# INLINE freshEff #-} - -renameForPrinting :: (EnvReader m, HoistableE e, SinkableE e, RenameE e) - => e n -> m n (Abs (Nest (AtomNameBinder CoreIR)) e n) +renameForPrinting :: EnvReader m + => (PairE CAtom CAtom n) -> m n (Abs (Nest (AtomNameBinder CoreIR)) (PairE CAtom CAtom) n) renameForPrinting e = do - infVars <- filterM isInferenceVar $ freeAtomVarsList e + infVars <- filterM isSolverName $ freeAtomVarsList e let ab = abstractFreeVarsNoAnn infVars e let hints = take (length infVars) $ map getNameHint $ map (:[]) ['a'..'z'] ++ map show [(0::Int)..] @@ -2612,18 +1999,18 @@ renameForPrinting e = do e' <- applyRename (bsAbs@@>nestToNames bs') eAbs return $ Abs bs' e' --- === dictionary synthesis === +-- === builder and type querying helpers === -synthTopE :: (EnvReader m, Fallible1 m, DictSynthTraversable e) => e n -> m n (e n) -synthTopE block = do - (liftExcept =<<) $ liftDictSynthTraverserM $ dsTraverse block -{-# SCC synthTopE #-} +makeStructRepVal :: (Fallible1 m, EnvReader m) => TyConName n -> [CAtom n] -> m n (CAtom n) +makeStructRepVal tyConName args = do + TyConDef _ _ _ (StructFields fields) <- lookupTyCon tyConName + case fields of + [_] -> case args of + [arg] -> return arg + _ -> error "wrong number of args" + _ -> return $ Con $ ProdCon args -synthTyConDef :: (EnvReader m, Fallible1 m) => TyConDef n -> m n (TyConDef n) -synthTyConDef (TyConDef sn roleExpls bs body) = (liftExcept =<<) $ liftDictSynthTraverserM do - dsTraverseExplBinders (snd <$> roleExpls) bs \bs' -> - TyConDef sn roleExpls bs' <$> dsTraverse body -{-# SCC synthTyConDef #-} +-- === dictionary synthesis === -- Given a simplified dict (an Atom of type `DictTy _` in the -- post-simplification IR), and a requested, more general, dict type, generalize @@ -2632,133 +2019,89 @@ synthTyConDef (TyConDef sn roleExpls bs body) = (liftExcept =<<) $ liftDictSynth -- valid to implement `generalizeDict` by re-synthesizing the whole dictionary, -- but we know that the derivation tree has to be the same, so we take the -- shortcut of just generalizing the data parameters. -generalizeDict :: (EnvReader m) => CType n -> Dict n -> m n (Dict n) +generalizeDict :: EnvReader m => CType n -> CDict n -> m n (CDict n) generalizeDict ty dict = do - result <- liftSolverM $ solveLocal $ generalizeDictAndUnify (sink ty) (sink dict) + result <- liftEnvReaderM $ liftM fst $ liftInfererMPure $ generalizeDictRec ty dict case result of Failure e -> error $ "Failed to generalize " ++ pprint dict - ++ " to " ++ pprint ty ++ " because " ++ pprint e + ++ " to " ++ show ty ++ " because " ++ pprint e Success ans -> return ans -generalizeDictAndUnify :: EmitsInf n => CType n -> Dict n -> SolverM n (Dict n) -generalizeDictAndUnify ty dict = do - dict' <- generalizeDictRec dict - dictTy <- return $ getType dict' - unify ty dictTy - zonk dict' - -generalizeDictRec :: EmitsInf n => Dict n -> SolverM n (Dict n) -generalizeDictRec dict = do - -- TODO: we should be able to avoid the normalization here . We only need it - -- because we sometimes end up with superclass projections. But they shouldn't - -- really be allowed to occur in the post-simplification IR. - DictCon _ dict' <- cheapNormalize dict - mkDictAtom =<< case dict' of - InstanceDict instanceName args -> do - InstanceDef _ roleExpls bs _ _ <- lookupInstanceDef instanceName - args' <- generalizeInstanceArgs roleExpls bs args - return $ InstanceDict instanceName args' - IxFin _ -> IxFin <$> Var <$> freshInferenceName MiscInfVar NatTy - InstantiatedGiven _ _ -> notSimplifiedDict - SuperclassProj _ _ -> notSimplifiedDict - DataData ty -> DataData <$> TyVar <$> freshInferenceName MiscInfVar ty - where notSimplifiedDict = error $ "Not a simplified dict: " ++ pprint dict - -generalizeInstanceArgs :: EmitsInf n => [RoleExpl] -> Nest CBinder n l -> [CAtom n] -> SolverM n [CAtom n] -generalizeInstanceArgs [] Empty [] = return [] -generalizeInstanceArgs ((role,_):expls) (Nest (b:>ty) bs) (arg:args) = do - arg' <- case role of - -- XXX: for `TypeParam` we can just emit a fresh inference name rather than - -- traversing the whole type like we do in `Generalize.hs`. The reason is - -- that it's valid to implement `generalizeDict` by synthesizing an entirely - -- fresh dictionary, and if we were to do that, we would infer this type - -- parameter exactly as we do here, using inference. - TypeParam -> Var <$> freshInferenceName MiscInfVar TyKind - DictParam -> generalizeDictAndUnify ty arg - DataParam -> Var <$> freshInferenceName MiscInfVar ty - Abs bs' UnitE <- applySubst (b@>SubstVal arg') (Abs bs UnitE) - args' <- generalizeInstanceArgs expls bs' args - return $ arg':args' -generalizeInstanceArgs _ _ _ = error "zip error" - -synthInstanceDefAndAddSynthCandidate - :: (Mut n, TopBuilder m, EnvReader m, Fallible1 m) => InstanceDef n -> m n (InstanceName n) -synthInstanceDefAndAddSynthCandidate def@(InstanceDef className expls bs params (InstanceBody superclasses _)) = do - let emptyDef = InstanceDef className expls bs params $ InstanceBody superclasses [] - instanceName <- emitInstanceDef emptyDef - addInstanceSynthCandidate className instanceName - synthInstanceDefRec instanceName def - return instanceName +generalizeDictRec :: CType n -> CDict n -> InfererM i n (CDict n) +generalizeDictRec targetTy (DictCon dict) = case dict of + InstanceDict _ instanceName args -> do + InstanceDef _ roleExpls bs _ _ <- lookupInstanceDef instanceName + liftSolverM $ generalizeInstanceArgs roleExpls bs args \args' -> do + d <- mkInstanceDict (sink instanceName) args' + -- We use rootSrcId here because we only call this after type inference so + -- precise source info isn't needed. + constrainEq rootSrcId (sink $ toAtom targetTy) (toAtom $ getType d) + return d + IxFin _ -> do + TyCon (DictTy (IxDictType (TyCon (NewtypeTyCon (Fin n))))) <- return targetTy + return $ DictCon $ IxFin n + DataData _ -> do + TyCon (DictTy (DataDictType t')) <- return targetTy + return $ DictCon $ DataData t' + IxRawFin _ -> error "not a simplified dict" +generalizeDictRec _ _ = error "not a simplified dict" + +generalizeInstanceArgs + :: Zonkable e => [RoleExpl] -> Nest CBinder o any -> [CAtom o] + -> (forall o'. DExt o o' => [CAtom o'] -> SolverM i o' (e o')) + -> SolverM i o (e o) +generalizeInstanceArgs [] Empty [] cont = withDistinct $ cont [] +generalizeInstanceArgs ((role,_):expls) (Nest (b:>ty) bs) (arg:args) cont = do + generalizeInstanceArg role ty arg \arg' -> do + Abs bs' UnitE <- applySubst (b@>SubstVal arg') (Abs bs UnitE) + generalizeInstanceArgs expls bs' (sink <$> args) \args' -> + cont $ sink arg' : args' +generalizeInstanceArgs _ _ _ _ = error "zip error" + +generalizeInstanceArg + :: Zonkable e => ParamRole -> CType o -> CAtom o + -> (forall o'. DExt o o' => CAtom o' -> SolverM i o' (e o')) + -> SolverM i o (e o) +generalizeInstanceArg role ty arg cont = case role of + -- XXX: for `TypeParam` we can just emit a fresh inference name rather than + -- traversing the whole type like we do in `Generalize.hs`. The reason is + -- that it's valid to implement `generalizeDict` by synthesizing an entirely + -- fresh dictionary, and if we were to do that, we would infer this type + -- parameter exactly as we do here, using inference. + TypeParam -> withFreshUnificationVarNoEmits rootSrcId MiscInfVar TyKind \v -> cont $ toAtom v + DictParam -> withFreshDictVarNoEmits ty ( + \ty' -> case toMaybeDict (sink arg) of + Just d -> liftM toAtom $ lift11 $ generalizeDictRec ty' d + _ -> error "not a dict") cont + DataParam -> withFreshUnificationVarNoEmits rootSrcId MiscInfVar ty \v -> cont $ toAtom v emitInstanceDef :: (Mut n, TopBuilder m) => InstanceDef n -> m n (Name InstanceNameC n) emitInstanceDef instanceDef@(InstanceDef className _ _ _ _) = do ty <- getInstanceType instanceDef emitBinding (getNameHint className) $ InstanceBinding instanceDef ty -type InstanceDefAbsBodyT = - ((ListE CAtom) `PairE` (ListE CAtom) `PairE` (ListE CAtom) `PairE` (ListE CAtom)) - -pattern InstanceDefAbsBody :: [CAtom n] -> [CAtom n] -> [CAtom n] -> [CAtom n] - -> InstanceDefAbsBodyT n -pattern InstanceDefAbsBody params superclasses doneMethods todoMethods = - ListE params `PairE` (ListE superclasses) `PairE` (ListE doneMethods) `PairE` (ListE todoMethods) - -type InstanceDefAbsT n = ([RoleExpl], Abs (Nest CBinder) InstanceDefAbsBodyT n) - -pattern InstanceDefAbs :: [RoleExpl] -> Nest CBinder h n -> [CAtom n] -> [CAtom n] -> [CAtom n] -> [CAtom n] - -> InstanceDefAbsT h -pattern InstanceDefAbs expls bs params superclasses doneMethods todoMethods = - (expls, Abs bs (InstanceDefAbsBody params superclasses doneMethods todoMethods)) - -synthInstanceDefRec - :: (Mut n, TopBuilder m, EnvReader m, Fallible1 m) => InstanceName n -> InstanceDef n -> m n () -synthInstanceDefRec instanceName def = do - InstanceDef className roleExplsTop bs params (InstanceBody superclasses methods) <- return def - let ab = InstanceDefAbs roleExplsTop bs params superclasses [] methods - recur ab className instanceName - where - recur :: (Mut n, TopBuilder m, EnvReader m, Fallible1 m) - => InstanceDefAbsT n -> ClassName n -> InstanceName n -> m n () - recur (InstanceDefAbs _ _ _ _ _ []) _ _ = return () - recur (roleExpls, ab) cname iname = do - (def', ab') <- liftExceptEnvReaderM $ refreshAbs ab - \bs' (InstanceDefAbsBody ps scs doneMethods (m:ms)) -> do - EnvReaderT $ ReaderT \(Distinct, env) -> do - let env' = extendSynthCandidatess (snd<$>roleExpls) bs' env - flip runReaderT (Distinct, env') $ runEnvReaderT' do - m' <- synthTopE m - let doneMethods' = doneMethods ++ [m'] - let ab' = InstanceDefAbs roleExpls bs' ps scs doneMethods' ms - let def' = InstanceDef cname roleExpls bs' ps $ InstanceBody scs doneMethods' - return (def', ab') - updateTopEnv $ UpdateInstanceDef iname def' - recur ab' cname iname - -synthInstanceDef - :: (EnvReader m, Fallible1 m) => InstanceDef n -> m n (InstanceDef n) -synthInstanceDef (InstanceDef className expls bs params body) = do - liftExceptEnvReaderM $ refreshAbs (Abs bs (ListE params `PairE` body)) - \bs' (ListE params' `PairE` InstanceBody superclasses methods) -> do - EnvReaderT $ ReaderT \(Distinct, env) -> do - let env' = extendSynthCandidatess (snd<$>expls) bs' env - flip runReaderT (Distinct, env') $ runEnvReaderT' do - methods' <- mapM synthTopE methods - return $ InstanceDef className expls bs' params' $ InstanceBody superclasses methods' - -- main entrypoint to dictionary synthesizer -trySynthTerm :: (Fallible1 m, EnvReader m) => CType n -> RequiredMethodAccess -> m n (SynthAtom n) -trySynthTerm ty reqMethodAccess = do +trySynthTerm :: SrcId -> CType n -> RequiredMethodAccess -> InfererM i n (SynthAtom n) +trySynthTerm sid ty reqMethodAccess = do hasInferenceVars ty >>= \case - True -> throw TypeErr "Can't synthesize a dictionary for a type with inference vars" - False -> do - synthTy <- liftExcept $ typeAsSynthType ty - solutions <- liftSyntherM $ synthTerm synthTy reqMethodAccess - case solutions of - [] -> throw TypeErr $ "Couldn't synthesize a class dictionary for: " ++ pprint ty - [d] -> cheapNormalize d -- normalize to reduce code size - _ -> throw TypeErr $ "Multiple candidate class dictionaries for: " ++ pprint ty + True -> throw sid $ CantSynthInfVars $ pprint ty + False -> withVoidSubst do + synthTy <- liftExcept $ typeAsSynthType sid ty + synthTerm sid synthTy reqMethodAccess + <|> (throw sid $ CantSynthDict $ pprint ty) {-# SCC trySynthTerm #-} +hasInferenceVars :: (EnvReader m, HoistableE e) => e n -> m n Bool +hasInferenceVars e = liftEnvReaderM $ anyInferenceVars $ freeAtomVarsList e +{-# INLINE hasInferenceVars #-} + +anyInferenceVars :: [CAtomName n] -> EnvReaderM n Bool +anyInferenceVars = \case + [] -> return False + (v:vs) -> isSolverName v >>= \case + True -> return True + False -> anyInferenceVars vs + type SynthAtom = CAtom type SynthPiType n = ([Explicitness], Abs (Nest CBinder) DictType n) data SynthType n = @@ -2768,38 +2111,13 @@ data SynthType n = data Givens n = Givens { fromGivens :: HM.HashMap (EKey SynthType n) (SynthAtom n) } -class (Alternative1 m, Searcher1 m, EnvReader m, EnvExtender m) - => Synther m where - getGivens :: m n (Givens n) - withGivens :: Givens n -> m n a -> m n a - -newtype SyntherM (n::S) (a:: *) = SyntherM - { runSyntherM' :: OutReaderT Givens (EnvReaderT []) n a } - deriving ( Functor, Applicative, Monad, EnvReader, EnvExtender - , ScopeReader, MonadFail - , Alternative, Searcher, OutReader Givens) - -instance Synther SyntherM where - getGivens = askOutReader - {-# INLINE getGivens #-} - withGivens givens cont = localOutReader givens cont - {-# INLINE withGivens #-} - -liftSyntherM :: EnvReader m => SyntherM n a -> m n [a] -liftSyntherM cont = - liftEnvReaderT do - initGivens <- givensFromEnv - runOutReaderT initGivens $ runSyntherM' cont -{-# INLINE liftSyntherM #-} - -givensFromEnv :: EnvReader m => m n (Givens n) -givensFromEnv = do - env <- withEnv moduleEnv - givens <- mapM toAtomVar $ lambdaDicts $ envSynthCandidates env - getSuperclassClosure (Givens HM.empty) (Var <$> givens) -{-# SCC givensFromEnv #-} - -extendGivens :: Synther m => [SynthAtom n] -> m n a -> m n a +getGivens :: InfererM i o (Givens o) +getGivens = givens <$> getInfState + +withGivens :: Givens o -> InfererM i o a -> InfererM i o a +withGivens givens cont = withInfState (\s -> s { givens = givens }) cont + +extendGivens :: [SynthAtom o] -> InfererM i o a -> InfererM i o a extendGivens newGivens cont = do prevGivens <- getGivens finalGivens <- getSuperclassClosure prevGivens newGivens @@ -2807,14 +2125,15 @@ extendGivens newGivens cont = do {-# INLINE extendGivens #-} getSynthType :: SynthAtom n -> SynthType n -getSynthType x = ignoreExcept $ typeAsSynthType (getType x) +getSynthType x = ignoreExcept $ typeAsSynthType rootSrcId (getType x) {-# INLINE getSynthType #-} -typeAsSynthType :: CType n -> Except (SynthType n) -typeAsSynthType = \case - DictTy dictTy -> return $ SynthDictType dictTy - Pi (CorePiType ImplicitApp expls bs (EffTy Pure (DictTy d))) -> return $ SynthPiType (expls, Abs bs d) - ty -> Failure $ Errs [Err TypeErr mempty $ "Can't synthesize terms of type: " ++ pprint ty] +typeAsSynthType :: SrcId -> CType n -> Except (SynthType n) +typeAsSynthType sid = \case + TyCon (DictTy dictTy) -> return $ SynthDictType dictTy + TyCon (Pi (CorePiType ImplicitApp expls bs (EffTy Pure (TyCon (DictTy d))))) -> + return $ SynthPiType (expls, Abs bs d) + ty -> Failure $ toErr sid $ NotASynthType $ pprint ty {-# SCC typeAsSynthType #-} getSuperclassClosure :: EnvReader m => Givens n -> [SynthAtom n] -> m n (Givens n) @@ -2824,8 +2143,7 @@ getSuperclassClosure givens newGivens = do return $ getSuperclassClosurePure env givens newGivens {-# INLINE getSuperclassClosure #-} -getSuperclassClosurePure - :: Distinct n => Env n -> Givens n -> [SynthAtom n] -> Givens n +getSuperclassClosurePure :: Distinct n => Env n -> Givens n -> [SynthAtom n] -> Givens n getSuperclassClosurePure env givens newGivens = snd $ runState (runEnvReaderT env (mapM_ visitGiven newGivens)) givens where @@ -2854,25 +2172,27 @@ getSuperclassClosurePure env givens newGivens = superclasses <- case synthTy of SynthPiType _ -> return [] SynthDictType dTy -> getSuperclassTys dTy - forM (enumerate superclasses) \(i, ty) -> do - return $ DictCon ty $ SuperclassProj synthExpr i + forM (enumerate superclasses) \(i, _) -> do + reduceSuperclassProj i $ fromJust (toMaybeDict synthExpr) -synthTerm :: SynthType n -> RequiredMethodAccess -> SyntherM n (SynthAtom n) -synthTerm targetTy reqMethodAccess = confuseGHC >>= \_ -> case targetTy of +synthTerm :: SrcId -> SynthType n -> RequiredMethodAccess -> InfererM i n (SynthAtom n) +synthTerm sid targetTy reqMethodAccess = confuseGHC >>= \_ -> case targetTy of SynthPiType (expls, ab) -> do - ab' <- withGivenBinders expls ab \bs targetTy' -> do - Abs bs <$> synthTerm (SynthDictType targetTy') reqMethodAccess - Abs bs synthExpr <- return ab' - liftM Lam $ coreLamExpr ImplicitApp expls $ Abs bs $ PairE Pure (WithoutDecls synthExpr) + ab' <- withFreshBindersInf expls ab \bs' targetTy' -> do + Abs bs' <$> synthTerm sid (SynthDictType targetTy') reqMethodAccess + Abs bs' synthExpr <- return ab' + let piTy = CorePiType ImplicitApp expls bs' (EffTy Pure (getType synthExpr)) + let lamExpr = LamExpr bs' (Atom synthExpr) + return $ toAtom $ Lam $ CoreLamExpr piTy lamExpr SynthDictType dictTy -> case dictTy of - DictType "Ix" _ [Type (NewtypeTyCon (Fin n))] -> return $ DictCon (DictTy dictTy) $ IxFin n - DictType "Data" _ [Type t] -> do - void (synthDictForData dictTy synthDictFromGiven dictTy) - return $ DictCon (DictTy dictTy) $ DataData t + IxDictType (TyCon (NewtypeTyCon (Fin n))) -> return $ toAtom $ IxFin n + DataDictType t -> do + void (synthDictForData sid dictTy <|> synthDictFromGiven sid dictTy) + return $ toAtom $ DataData t _ -> do - dict <- synthDictFromInstance dictTy synthDictFromGiven dictTy + dict <- synthDictFromInstance sid dictTy <|> synthDictFromGiven sid dictTy case dict of - DictCon _ (InstanceDict instanceName _) -> do + Con (DictConAtom (InstanceDict _ instanceName _)) -> do isReqMethodAccessAllowed <- reqMethodAccess `isMethodAccessAllowedBy` instanceName if isReqMethodAccessAllowed then return dict @@ -2880,124 +2200,96 @@ synthTerm targetTy reqMethodAccess = confuseGHC >>= \_ -> case targetTy of _ -> return dict {-# SCC synthTerm #-} -coreLamExpr :: EnvReader m => AppExplicitness - -> [Explicitness] -> Abs (Nest CBinder) (PairE (EffectRow CoreIR) CBlock) n - -> m n (CoreLamExpr n) -coreLamExpr appExpl expls ab = liftEnvReaderM do - refreshAbs ab \bs' (PairE effs' body') -> do - EffTy _ resultTy <- blockEffTy body' - return $ CoreLamExpr (CorePiType appExpl expls bs' (EffTy effs' resultTy)) (LamExpr bs' body') - -withGivenBinders - :: (SinkableE e, RenameE e) => [Explicitness] -> Abs (Nest CBinder) e n - -> (forall l. DExt n l => Nest CBinder n l -> e l -> SyntherM l a) - -> SyntherM n a -withGivenBinders explsTop (Abs bsTop e) contTop = - runSubstReaderT idSubst $ go explsTop bsTop \bsTop' -> do - e' <- renameM e - liftSubstReaderT $ contTop bsTop' e' - where - go :: [Explicitness] -> Nest CBinder i i' - -> (forall o'. DExt o o' => Nest CBinder o o' -> SubstReaderT Name SyntherM i' o' a) - -> SubstReaderT Name SyntherM i o a - go expls bs cont = case (expls, bs) of - ([], Empty) -> getDistinct >>= \Distinct -> cont Empty - (expl:explsRest, Nest b rest) -> do - argTy <- renameM $ binderType b - withFreshBinder (getNameHint b) argTy \b' -> do - givens <- case expl of - Inferred _ (Synth _) -> return [Var $ binderVar b'] - _ -> return [] - s <- getSubst - liftSubstReaderT $ extendGivens givens $ - runSubstReaderT (s <>> b@>binderName b') $ - go explsRest rest \rest' -> cont (Nest b' rest') - _ -> error "zip error" - isMethodAccessAllowedBy :: EnvReader m => RequiredMethodAccess -> InstanceName n -> m n Bool isMethodAccessAllowedBy access instanceName = do InstanceDef className _ _ _ (InstanceBody _ methods) <- lookupInstanceDef instanceName let numInstanceMethods = length methods - ClassDef _ _ _ _ _ _ methodTys <- lookupClassDef className + ClassDef _ _ _ _ _ _ _ methodTys <- lookupClassDef className let numClassMethods = length methodTys case access of Full -> return $ numClassMethods == numInstanceMethods Partial numReqMethods -> return $ numReqMethods <= numInstanceMethods -synthDictFromGiven :: DictType n -> SyntherM n (SynthAtom n) -synthDictFromGiven targetTy = do +synthDictFromGiven :: SrcId -> DictType n -> InfererM i n (SynthAtom n) +synthDictFromGiven sid targetTy = do givens <- ((HM.elems . fromGivens) <$> getGivens) asum $ givens <&> \given -> do case getSynthType given of SynthDictType givenDictTy -> do guard =<< alphaEq targetTy givenDictTy return given - SynthPiType givenPiTy -> do - args <- instantiateSynthArgs targetTy givenPiTy - return $ DictCon (DictTy targetTy) $ InstantiatedGiven given args - -synthDictFromInstance :: DictType n -> SyntherM n (SynthAtom n) -synthDictFromInstance targetTy@(DictType _ targetClass _) = do - instances <- getInstanceDicts targetClass - asum $ instances <&> \candidate -> do - CorePiType _ expls bs (EffTy _ (DictTy candidateTy)) <- lookupInstanceTy candidate - args <- instantiateSynthArgs targetTy (expls, Abs bs candidateTy) - return $ DictCon (DictTy targetTy) $ InstanceDict candidate args - -instantiateSynthArgs :: DictType n -> SynthPiType n -> SyntherM n [CAtom n] -instantiateSynthArgs targetTop (explsTop, Abs bsTop resultTyTop) = do - ListE args <- (liftExceptAlt =<<) $ liftSolverM $ solveLocal do - args <- runSubstReaderT idSubst $ go (sink targetTop) explsTop (sink $ Abs bsTop resultTyTop) - zonk $ ListE args - forM args \case - DictHole _ argTy req -> liftExceptAlt (typeAsSynthType argTy) >>= flip synthTerm req - arg -> return arg - where - go :: EmitsInf o - => DictType o -> [Explicitness] -> Abs (Nest CBinder) DictType i - -> SubstReaderT AtomSubstVal SolverM i o [CAtom o] - go target allExpls (Abs bs proposed) = case (allExpls, bs) of - ([], Empty) -> do - proposed' <- substM proposed - liftSubstReaderT $ unify target proposed' - return [] - (expl:expls, Nest b rest) -> do - argTy <- substM $ binderType b - arg <- liftSubstReaderT case expl of - Explicit -> error "instances shouldn't have explicit args" - Inferred _ Unify -> Var <$> freshInferenceName MiscInfVar argTy - Inferred _ (Synth req) -> return $ DictHole (AlwaysEqual emptySrcPosCtx) argTy req - liftM (arg:) $ extendSubst (b@>SubstVal arg) $ go target expls (Abs rest proposed) - _ -> error "zip error" - -synthDictForData :: forall n. DictType n -> SyntherM n (SynthAtom n) -synthDictForData dictTy@(DictType "Data" dName [Type ty]) = case ty of + SynthPiType givenPiTy -> typeErrAsSearchFailure do + args <- instantiateSynthArgs sid targetTy givenPiTy + reduceInstantiateGiven given args + +synthDictFromInstance :: SrcId -> DictType n -> InfererM i n (SynthAtom n) +synthDictFromInstance sid targetTy = do + instances <- getInstanceDicts targetTy + asum $ instances <&> \candidate -> typeErrAsSearchFailure do + CorePiType _ expls bs (EffTy _ (TyCon (DictTy candidateTy))) <- lookupInstanceTy candidate + args <- instantiateSynthArgs sid targetTy (expls, Abs bs candidateTy) + return $ toAtom $ InstanceDict (toType targetTy) candidate args + +getInstanceDicts :: EnvReader m => DictType n -> m n [InstanceName n] +getInstanceDicts dictTy = do + env <- withEnv (envSynthCandidates . moduleEnv) + case dictTy of + DictType _ name _ -> return $ M.findWithDefault [] name $ instanceDicts env + IxDictType _ -> return $ ixInstances env + DataDictType _ -> return [] + +addInstanceSynthCandidate :: TopBuilder m => ClassName n -> Maybe BuiltinClassName -> InstanceName n -> m n () +addInstanceSynthCandidate className maybeBuiltin instanceName = do + sc <- return case maybeBuiltin of + Nothing -> mempty {instanceDicts = M.singleton className [instanceName] } + Just Ix -> mempty {ixInstances = [instanceName]} + Just Data -> mempty + emitLocalModuleEnv $ mempty {envSynthCandidates = sc} + +instantiateSynthArgs :: SrcId -> DictType n -> SynthPiType n -> InfererM i n [CAtom n] +instantiateSynthArgs sid target (expls, synthPiTy) = do + liftM fromListE $ withReducibleEmissions sid CantReduceDict do + bsConstrained <- buildConstraints (sink synthPiTy) \_ resultTy -> do + return [TypeConstraint sid (TyCon $ DictTy $ sink target) (TyCon $ DictTy resultTy)] + ListE <$> inferMixedArgs sid "dict" expls bsConstrained emptyMixedArgs + +emptyMixedArgs :: MixedArgs (CAtom n) +emptyMixedArgs = ([], []) + +typeErrAsSearchFailure :: InfererM i n a -> InfererM i n a +typeErrAsSearchFailure cont = cont `catchErr` \case + TypeErr _ _ -> empty + e -> throwErr e + +synthDictForData :: forall i n. SrcId -> DictType n -> InfererM i n (SynthAtom n) +synthDictForData sid dictTy@(DataDictType ty) = case ty of -- TODO Deduplicate vs CheckType.checkDataLike - -- The "Var" case is different - TyVar _ -> synthDictFromGiven dictTy - TabPi (TabPiType _ b eltTy) -> recurBinder (Abs b eltTy) >> success - DepPairTy (DepPairType _ b@(_:>l) r) -> do - recur l >> recurBinder (Abs b r) >> success - NewtypeTyCon nt -> do - (_, ty') <- unwrapNewtypeType nt - recur ty' >> success - TC con -> case con of - BaseType _ -> success - ProdType as -> mapM_ recur as >> success - SumType cs -> mapM_ recur cs >> success - RefType _ _ -> success - HeapType -> success - _ -> notData - _ -> notData + -- The "Stuck" case is different + StuckTy _ _ -> synthDictFromGiven sid dictTy + TyCon con -> case con of + TabPi (TabPiType _ b eltTy) -> recurBinder (Abs b eltTy) >> success + DepPairTy (DepPairType _ b@(_:>l) r) -> do + recur l >> recurBinder (Abs b r) >> success + NewtypeTyCon nt -> do + (_, ty') <- unwrapNewtypeType nt + recur ty' >> success + BaseType _ -> success + ProdType as -> mapM_ recur as >> success + SumType cs -> mapM_ recur cs >> success + RefType _ _ -> success + HeapType -> success + _ -> notData where - recur ty' = synthDictForData $ DictType "Data" dName [Type ty'] - recurBinder :: (RenameB b, BindsEnv b) => Abs b CType n -> SyntherM n (SynthAtom n) - recurBinder bAbs = refreshAbs bAbs \b' ty'' -> do - ans <- synthDictForData $ DictType "Data" (sink dName) [Type ty''] - return $ ignoreHoistFailure $ hoist b' ans + recur ty' = synthDictForData sid $ DataDictType ty' + recurBinder :: Abs CBinder CType n -> InfererM i n (SynthAtom n) + recurBinder (Abs b body) = + withFreshBinderInf noHint Explicit (binderType b) \b' -> do + body' <- applyRename (b@>binderName b') body + ans <- synthDictForData sid $ DataDictType (toType body') + return $ ignoreHoistFailure $ hoist b' ans notData = empty - success = return $ DictCon (DictTy dictTy) $ DataData ty -synthDictForData dictTy = error $ "Malformed Data dictTy " ++ pprint dictTy + success = return $ toAtom $ DataData ty +synthDictForData _ dictTy = error $ "Malformed Data dictTy " ++ pprint dictTy instance GenericE Givens where type RepE Givens = HashMapE (EKey SynthType) SynthAtom @@ -3008,188 +2300,29 @@ instance GenericE Givens where instance SinkableE Givens where --- === Dictionary synthesis traversal === - -liftDictSynthTraverserM - :: EnvReader m - => DictSynthTraverserM n n a - -> m n (Except a) -liftDictSynthTraverserM m = do - (ans, LiftE errs) <- liftM runHardFail $ liftBuilderT $ - runStateT1 (runSubstReaderT idSubst $ runDictSynthTraverserM m) (LiftE $ Errs []) - return $ case errs of - Errs [] -> Success ans - _ -> Failure errs - -newtype DictSynthTraverserM i o a = - DictSynthTraverserM - { runDictSynthTraverserM :: - SubstReaderT Name (StateT1 (LiftE Errs) (BuilderM CoreIR)) i o a} - deriving (MonadFail, Fallible, Functor, Applicative, Monad, ScopeReader, - EnvReader, EnvExtender, Builder CoreIR, SubstReader Name, - ScopableBuilder CoreIR, MonadState (LiftE Errs o)) - -instance NonAtomRenamer (DictSynthTraverserM i o) i o where renameN = renameM -instance Visitor (DictSynthTraverserM i o) CoreIR i o where - visitType = dsTraverse - visitAtom = dsTraverse - visitPi = visitPiDefault - visitLam = visitLamNoEmits -instance ExprVisitorNoEmits (DictSynthTraverserM i o) CoreIR i o where - visitExprNoEmits = visitGeneric - -class DictSynthTraversable (e::E) where - dsTraverse :: e i -> DictSynthTraverserM i o (e o) - -instance DictSynthTraversable (TopLam CoreIR) where - dsTraverse (TopLam d ty lam) = TopLam d <$> visitPiDefault ty <*> visitLamNoEmits lam - -instance DictSynthTraversable CAtom where - dsTraverse atom = case atom of - DictHole (AlwaysEqual ctx) ty access -> do - ty' <- cheapNormalize =<< dsTraverse ty - ans <- liftEnvReaderT $ addSrcContext ctx $ trySynthTerm ty' access - case ans of - Failure errs -> put (LiftE errs) >> renameM atom - Success d -> return d - Lam (CoreLamExpr piTy@(CorePiType _ expls _ _) (LamExpr bsLam (Abs decls result))) -> do - Pi piTy' <- dsTraverse $ Pi piTy - lam' <- dsTraverseExplBinders expls bsLam \bsLam' -> do - visitDeclsNoEmits decls \decls' -> do - LamExpr bsLam' <$> Abs decls' <$> dsTraverse result - return $ Lam $ CoreLamExpr piTy' lam' - Var _ -> renameM atom - SimpInCore _ -> renameM atom - ProjectElt _ _ _ -> renameM atom - _ -> visitAtomPartial atom - -instance DictSynthTraversable CType where - dsTraverse ty = case ty of - Pi (CorePiType appExpl expls bs (EffTy effs resultTy)) -> Pi <$> - dsTraverseExplBinders expls bs \bs' -> do - CorePiType appExpl expls bs' <$> (EffTy <$> renameM effs <*> dsTraverse resultTy) - TyVar _ -> renameM ty - ProjectEltTy _ _ _ -> renameM ty - _ -> visitTypePartial ty - -instance DictSynthTraversable DataConDefs where dsTraverse = visitGeneric - -dsTraverseExplBinders - :: [Explicitness] -> Nest CBinder i i' - -> (forall o'. DExt o o' => Nest CBinder o o' -> DictSynthTraverserM i' o' a) - -> DictSynthTraverserM i o a -dsTraverseExplBinders [] Empty cont = getDistinct >>= \Distinct -> cont Empty -dsTraverseExplBinders (expl:expls) (Nest b bs) cont = do - ty <- dsTraverse $ binderType b - withFreshBinder (getNameHint b) ty \b' -> do - let v = binderName b' - extendSynthCandidatesDict expl v $ extendRenamer (b@>v) do - dsTraverseExplBinders expls bs \bs' -> cont $ Nest b' bs' -dsTraverseExplBinders _ _ _ = error "zip error" - -extendSynthCandidatesDict :: Explicitness -> CAtomName n -> DictSynthTraverserM i n a -> DictSynthTraverserM i n a -extendSynthCandidatesDict c v cont = DictSynthTraverserM do - SubstReaderT $ ReaderT \env -> StateT1 \s -> BuilderT do - extendInplaceTLocal (extendSynthCandidates c v) $ runBuilderT' $ - runStateT1 (runSubstReaderT env $ runDictSynthTraverserM $ cont) s -{-# INLINE extendSynthCandidatesDict #-} - -- === Inference-specific builder patterns === --- The higher-order functions in Builder, like `buildLam` can't be easily used --- in inference because they don't allow for the emission of inference --- variables, which must be handled each time we leave a scope. In an earlier --- version we tried to put this logic in the implementation of InfererM's --- instance of Builder, but it forced us to overfit the Builder API to satisfy --- the needs of inference, like adding `SubstE AtomSubstVal e` constraints in --- various places. - type WithExpl = WithAttrB Explicitness type WithRoleExpl = WithAttrB RoleExpl -buildBlockInf - :: EmitsInf n - => (forall l. (EmitsBoth l, DExt n l) => InfererM i l (CAtom l)) - -> InfererM i n (CBlock n) -buildBlockInf cont = do - Abs decls (PairE result ty) <- buildDeclsInf do - ans <- cont - ty <- cheapNormalize $ getType ans - return $ PairE ans ty - let msg = "Block:" <> nest 1 (prettyBlock decls result) <> line - <> group ("Of type:" <> nest 2 (line <> pretty ty)) <> line - void $ liftHoistExcept' (docAsStr msg) $ hoist decls ty - return $ Abs decls result -{-# INLINE buildBlockInf #-} - buildBlockInfWithRecon - :: (EmitsInf n, RenameE e, HoistableE e, SinkableE e) - => (forall l. (EmitsBoth l, DExt n l) => InfererM i l (e l)) - -> InfererM i n (PairE CBlock (ReconAbs CoreIR e) n) + :: HasNamesE e + => (forall l. (Emits l, DExt n l) => InfererM i l (e l)) + -> InfererM i n (PairE CExpr (ReconAbs CoreIR e) n) buildBlockInfWithRecon cont = do - ab <- buildDeclsInfUnzonked cont - (block, recon) <- refreshAbs ab \decls result -> do + ab <- buildScoped cont + (block, recon) <- liftEnvReaderM $ refreshAbs ab \decls result -> do (newResult, recon) <- telescopicCapture decls result return (Abs decls newResult, recon) - return $ PairE block recon + block' <- mkBlock block + return $ PairE block' recon {-# INLINE buildBlockInfWithRecon #-} -buildTabPiInf - :: EmitsInf n - => NameHint -> IxType CoreIR n - -> (forall l. (EmitsInf l, Ext n l) => CAtomVar l -> InfererM i l (CType l)) - -> InfererM i n (TabPiType CoreIR n) -buildTabPiInf hint (IxType t d) body = do - Abs b resultTy <- buildAbsInf hint Explicit t \v -> withoutEffects $ body v - return $ TabPiType d b resultTy - -buildDepPairTyInf - :: EmitsInf n - => NameHint -> DepPairExplicitness -> CType n - -> (forall l. (EmitsInf l, Ext n l) => CAtomVar l -> InfererM i l (CType l)) - -> InfererM i n (DepPairType CoreIR n) -buildDepPairTyInf hint expl ty body = do - Abs b resultTy <- buildAbsInf hint Explicit ty body - return $ DepPairType expl b resultTy - -buildAltInf - :: EmitsInf n - => CType n - -> (forall l. (EmitsBoth l, Ext n l) => CAtomVar l -> InfererM i l (CAtom l)) - -> InfererM i n (Alt CoreIR n) -buildAltInf ty body = do - buildAbsInf noHint Explicit ty \v -> - buildBlockInf do - Distinct <- getDistinct - body $ sink v - --- === EmitsInf predicate === - -type EmitsBoth n = (EmitsInf n, Emits n) - -class Mut n => EmitsInf (n::S) -data EmitsInfEvidence (n::S) where - EmitsInf :: EmitsInf n => EmitsInfEvidence n -instance EmitsInf UnsafeS - -fabricateEmitsInfEvidence :: forall n. EmitsInfEvidence n -fabricateEmitsInfEvidence = withFabricatedEmitsInf @n EmitsInf - -fabricateEmitsInfEvidenceM :: forall m n. Monad1 m => m n (EmitsInfEvidence n) -fabricateEmitsInfEvidenceM = return fabricateEmitsInfEvidence - -withFabricatedEmitsInf :: forall n a. (EmitsInf n => a) -> a -withFabricatedEmitsInf cont = fromWrapWithEmitsInf - ( TrulyUnsafe.unsafeCoerce ( WrapWithEmitsInf cont :: WrapWithEmitsInf n a - ) :: WrapWithEmitsInf UnsafeS a) -newtype WrapWithEmitsInf n r = - WrapWithEmitsInf { fromWrapWithEmitsInf :: EmitsInf n => r } - -- === IFunType === asFFIFunType :: EnvReader m => CType n -> m n (Maybe (IFunType, CorePiType n)) asFFIFunType ty = return do - Pi piTy <- return ty + TyCon (Pi piTy) <- return ty impTy <- checkFFIFunTypeM piTy return (impTy, piTy) @@ -3212,7 +2345,7 @@ checkFFIFunTypeM _ = error "expected at least one argument" checkScalar :: (IRRep r, Fallible m) => Type r n -> m BaseType checkScalar (BaseTy ty) = return ty -checkScalar ty = throw TypeErr $ pprint ty +checkScalar ty = throw rootSrcId $ FFIArgTyNotScalar $ pprint ty checkScalarOrPairType :: (IRRep r, Fallible m) => Type r n -> m [BaseType] checkScalarOrPairType (PairTy a b) = do @@ -3220,10 +2353,53 @@ checkScalarOrPairType (PairTy a b) = do tys2 <- checkScalarOrPairType b return $ tys1 ++ tys2 checkScalarOrPairType (BaseTy ty) = return [ty] -checkScalarOrPairType ty = throw TypeErr $ pprint ty +checkScalarOrPairType ty = throw rootSrcId $ FFIResultTyErr $ pprint ty -- === instances === +instance DiffStateE SolverSubst SolverDiff where + updateDiffStateE :: forall n. Distinct n => Env n -> SolverSubst n -> SolverDiff n -> SolverSubst n + updateDiffStateE _ initState (SolverDiff (RListE diffs)) = foldl update' initState (unsnoc diffs) + where + update' :: Distinct n => SolverSubst n -> Solution n -> SolverSubst n + update' (SolverSubst subst) (PairE v x) = SolverSubst $ M.insert v x subst + +instance SinkableE InfState where sinkingProofE _ = todoSinkableProof + +instance GenericE SigmaAtom where + type RepE SigmaAtom = EitherE3 (LiftE (Maybe SourceName) `PairE` CAtom) + (LiftE SourceName `PairE` CType `PairE` UVar) + (CType `PairE` CAtom `PairE` ListE CAtom) + fromE = \case + SigmaAtom x y -> Case0 $ LiftE x `PairE` y + SigmaUVar x y z -> Case1 $ LiftE x `PairE` y `PairE` z + SigmaPartialApp x y z -> Case2 $ x `PairE` y `PairE` ListE z + {-# INLINE fromE #-} + + toE = \case + Case0 (LiftE x `PairE` y) -> SigmaAtom x y + Case1 (LiftE x `PairE` y `PairE` z) -> SigmaUVar x y z + Case2 (x `PairE` y `PairE` ListE z) -> SigmaPartialApp x y z + _ -> error "impossible" + {-# INLINE toE #-} + +instance RenameE SigmaAtom +instance HoistableE SigmaAtom +instance SinkableE SigmaAtom + +instance SubstE AtomSubstVal SigmaAtom where + substE env (SigmaAtom sn x) = SigmaAtom sn $ substE env x + substE env (SigmaUVar sn ty uvar) = case uvar of + UAtomVar v -> substE env $ SigmaAtom (Just sn) $ toAtom (AtomVar v ty) + UTyConVar v -> SigmaUVar sn ty' $ UTyConVar $ substE env v + UDataConVar v -> SigmaUVar sn ty' $ UDataConVar $ substE env v + UPunVar v -> SigmaUVar sn ty' $ UPunVar $ substE env v + UClassVar v -> SigmaUVar sn ty' $ UClassVar $ substE env v + UMethodVar v -> SigmaUVar sn ty' $ UMethodVar $ substE env v + where ty' = substE env ty + substE env (SigmaPartialApp ty f xs) = + SigmaPartialApp (substE env ty) (substE env f) (map (substE env) xs) + instance PrettyE e => Pretty (UDeclInferenceResult e l) where pretty = \case UDeclResultDone e -> pretty e @@ -3241,34 +2417,6 @@ instance (RenameE e, CheckableE CoreIR e) => CheckableE CoreIR (UDeclInferenceRe UDeclResultBindPattern hint block recon -> UDeclResultBindPattern hint <$> checkE block <*> renameM recon -- TODO: check recon -instance HasType CoreIR InfEmission where - getType = \case - LeftE (DeclBinding _ e) -> getType e - RightE b -> case b of - InfVarBound t _ -> t - SkolemBound t -> t - -instance (Monad m, ExtOutMap InfOutMap decls, OutFrag decls) - => EnvReader (InplaceT InfOutMap decls m) where - unsafeGetEnv = do - InfOutMap env _ _ _ _ <- getOutMapInplaceT - return env - -instance (Monad m, ExtOutMap InfOutMap decls, OutFrag decls) - => EnvExtender (InplaceT InfOutMap decls m) where - refreshAbs ab cont = UnsafeMakeInplaceT \env decls -> - refreshAbsPure (toScope env) ab \_ b e -> do - let subenv = extendOutMap env $ toEnvFrag b - (ans, d, _) <- unsafeRunInplaceT (cont b e) subenv emptyOutFrag - case fabricateDistinctEvidence @UnsafeS of - Distinct -> do - let env' = extendOutMap (unsafeCoerceE env) d - return (ans, catOutFrags decls d, env') - {-# INLINE refreshAbs #-} - -instance BindsEnv InfOutFrag where - toEnvFrag (InfOutFrag frag _ _) = toEnvFrag frag - instance GenericE SynthType where type RepE SynthType = EitherE2 DictType (PairE (LiftE [Explicitness]) (Abs (Nest CBinder) DictType)) fromE (SynthDictType d) = Case0 d @@ -3284,6 +2432,76 @@ instance HoistableE SynthType instance RenameE SynthType instance SubstE AtomSubstVal SynthType +instance GenericE Constraint where + type RepE Constraint = PairE + (LiftE SrcId) + (EitherE + (PairE CType CType) + (PairE (EffectRow CoreIR) (EffectRow CoreIR))) + fromE (TypeConstraint sid t1 t2) = LiftE sid `PairE` LeftE (PairE t1 t2) + fromE (EffectConstraint sid e1 e2) = LiftE sid `PairE` RightE (PairE e1 e2) + {-# INLINE fromE #-} + toE (LiftE sid `PairE` LeftE (PairE t1 t2)) = TypeConstraint sid t1 t2 + toE (LiftE sid `PairE` RightE (PairE e1 e2)) = EffectConstraint sid e1 e2 + {-# INLINE toE #-} + +instance SinkableE Constraint +instance HoistableE Constraint +instance (SubstE AtomSubstVal) Constraint + +instance GenericE RequiredTy where + type RepE RequiredTy = EitherE CType UnitE + fromE (Check ty) = LeftE ty + fromE Infer = RightE UnitE + {-# INLINE fromE #-} + toE (LeftE ty) = Check ty + toE (RightE UnitE) = Infer + {-# INLINE toE #-} + +instance SinkableE RequiredTy +instance HoistableE RequiredTy +instance AlphaEqE RequiredTy +instance RenameE RequiredTy + +instance GenericE PartialType where + type RepE PartialType = EitherE PartialPiType CType + fromE (PartialType ty) = LeftE ty + fromE (FullType ty) = RightE ty + {-# INLINE fromE #-} + toE (LeftE ty) = PartialType ty + toE (RightE ty) = FullType ty + {-# INLINE toE #-} + +instance SinkableE PartialType +instance HoistableE PartialType +instance AlphaEqE PartialType +instance RenameE PartialType + +instance GenericE SolverSubst where + -- XXX: this is a bit sketchy because it's not actually bijective... + type RepE SolverSubst = ListE (PairE CAtomName CAtom) + fromE (SolverSubst m) = ListE $ map (uncurry PairE) $ M.toList m + {-# INLINE fromE #-} + toE (ListE pairs) = SolverSubst $ M.fromList $ map fromPairE pairs + {-# INLINE toE #-} + +instance SinkableE SolverSubst where +instance RenameE SolverSubst where +instance HoistableE SolverSubst + +instance GenericE PartialPiType where + type RepE PartialPiType = LiftE (AppExplicitness, [Explicitness]) `PairE` Abs (Nest CBinder) + (EffectRow CoreIR `PairE` RequiredTy) + fromE (PartialPiType ex exs b eff ty) = LiftE (ex, exs) `PairE` Abs b (PairE eff ty) + {-# INLINE fromE #-} + toE (LiftE (ex, exs) `PairE` Abs b (PairE eff ty)) = PartialPiType ex exs b eff ty + {-# INLINE toE #-} + +instance SinkableE PartialPiType +instance HoistableE PartialPiType +instance AlphaEqE PartialPiType +instance RenameE PartialPiType + -- See Note [Confuse GHC] from Simplify.hs confuseGHC :: EnvReader m => m n (DistinctEvidence n) confuseGHC = getDistinct diff --git a/src/lib/Inference.hs-boot b/src/lib/Inference.hs-boot deleted file mode 100644 index a8f219389..000000000 --- a/src/lib/Inference.hs-boot +++ /dev/null @@ -1,14 +0,0 @@ --- Copyright 2021 Google LLC --- --- Use of this source code is governed by a BSD-style --- license that can be found in the LICENSE file or at --- https://developers.google.com/open-source/licenses/bsd - -module Inference (trySynthTerm) where - -import Core -import Name -import Types.Core -import Types.Primitives (RequiredMethodAccess) - -trySynthTerm :: (Fallible1 m, EnvReader m) => CType n -> RequiredMethodAccess -> m n (CAtom n) diff --git a/src/lib/Inline.hs b/src/lib/Inline.hs index 36cfc1c8e..f72f24bef 100644 --- a/src/lib/Inline.hs +++ b/src/lib/Inline.hs @@ -6,8 +6,6 @@ module Inline (inlineBindings) where -import Data.List.NonEmpty qualified as NE - import Builder import Core import Err @@ -16,15 +14,15 @@ import IRVariants import Name import Subst import Occurrence hiding (Var) -import Optimize +import PeepholeOptimize import Types.Core import Types.Primitives +import Types.Top -- === External API === inlineBindings :: (EnvReader m) => STopLam n -> m n (STopLam n) -inlineBindings = liftLamExpr \(Abs decls ans) -> liftInlineM $ - buildScoped $ inlineDecls decls $ inline Stop ans +inlineBindings lam = liftLamExpr lam \body -> liftInlineM $ buildBlock $ inlineExpr Stop body {-# INLINE inlineBindings #-} {-# SCC inlineBindings #-} @@ -75,12 +73,6 @@ data SizePreservationInfo = | UsedMulti deriving (Eq, Show) -inlineDecls :: Emits o => Nest SDecl i i' -> InlineM i' o a -> InlineM i o a -inlineDecls decls cont = do - s <- inlineDeclsSubst decls - withSubst s cont -{-# INLINE inlineDecls #-} - inlineDeclsSubst :: Emits o => Nest SDecl i i' -> InlineM i o (Subst InlineSubstVal i' o) inlineDeclsSubst = \case Empty -> getSubst @@ -89,75 +81,37 @@ inlineDeclsSubst = \case s <- getSubst extendSubst (b @> SubstVal (SuspEx expr s)) $ inlineDeclsSubst rest else do - expr' <- inlineExpr Stop expr >>= (liftEnvReaderM . peepholeExpr) + expr' <- peepholeExpr <$> inlineExpr Stop expr -- If the inliner starts moving effectful expressions, it may become -- necessary to query the effects of the new expression here. let presInfo = resolveWorkConservation ann expr' - -- A subtlety from the Secrets paper. In Haskell, it is feasible to have - -- a binding whose occurrence information indicates multiple uses, but - -- which does a small, bounded amount of runtime work. GHC will inline - -- such a binding, but not into contexts where GHC knows that no further - -- optimizations are possible. The example given in the paper is - -- f = \x -> E - -- g = \ys -> map f ys - -- Inlining f here is useless because it's not applied, and mildly costly - -- because it causes the closure to be allocated at every call to g rather - -- than just once. - -- TODO If we want to track this subtlety, we should make room for it in - -- the SizePreservationInfo ADT (maybe rename it), maybe with a - -- OnceButDuplicatesBoundedWork constructor. Then only the true UsedOnce - -- would be inlined unconditionally here, and the - -- OnceButDuplicatesBoundedWork constructor could be inlined or not - -- depending on its usage context. (This would correspond to the case - -- OnceUnsafe with whnfOrBot == True in the Secrets paper.) + -- See NoteSecretsSubtlety if presInfo == UsedOnce then do let substVal = case expr' of - Atom (Var name') -> Rename $ atomVarName name' + Atom (Stuck _ (Var name')) -> Rename $ atomVarName name' _ -> SubstVal (DoneEx expr') extendSubst (b @> substVal) $ inlineDeclsSubst rest else do -- expr' can't be Atom (Var x) here name' <- emitDecl (getNameHint b) (dropOccInfo ann) expr' extendSubst (b @> Rename (atomVarName name')) do - -- TODO For now, this inliner does not do any conditional inlining. - -- In order to do it, we would need to augment the environment at this - -- point, associating name' to (expr', presInfo) so name' could be - -- inlined at use sites. - -- - -- Conditional inlining is different in Dex vs Haskell because Dex is - -- strict. To wit, once we have emitted the bidning for `expr'`, we - -- are committed to doing the work it represents unless it's inlined - -- _everywhere_. For example, - -- xs = - -- case of - -- Nothing -> xs -- ok to inline here - -- Just _ -> xs ... xs -- not ok here - -- If this were Haskell, it would be work-preserving for GHC to inline - -- `xs` into the `Nothing` arm, but in Dex it's not, unless we first - -- explicitly push the binding into the case like - -- case of - -- Nothing -> xs = ; xs - -- Just _ -> xs = ; xs ... xs - -- - -- That said, the Secrets paper says that GHC only conditionally - -- inlines zero-work bindings anyway (or, more precisely, "bounded - -- finite work" bindings). All the heuristics about whether to inline - -- at a particular site are about code size and not increasing it - -- overmuch. But, of course, inlining even zero-work bindings can - -- help runtime performance because it can unblock other optimizations - -- that otherwise could not occur across the binding. + -- See NoteConditionalInlining inlineDeclsSubst rest where dropOccInfo PlainLet = PlainLet + dropOccInfo LinearLet = LinearLet + dropOccInfo InlineLet = InlineLet dropOccInfo NoInlineLet = NoInlineLet dropOccInfo (OccInfoPure _) = PlainLet dropOccInfo (OccInfoImpure _) = PlainLet resolveWorkConservation PlainLet _ = NoInline -- No occurrence info, assume the worst + resolveWorkConservation LinearLet _ = NoInline + resolveWorkConservation InlineLet _ = NoInline resolveWorkConservation NoInlineLet _ = NoInline -- Quick hack to always unconditionally inline renames, until we get -- a better story about measuring the sizes of atoms and expressions. - resolveWorkConservation (OccInfoPure _) (Atom (Var _)) = UsedOnce + resolveWorkConservation (OccInfoPure _) (Atom (Stuck _ (Var _))) = UsedOnce resolveWorkConservation (OccInfoPure (UsageInfo s (ixDepth, d))) expr | d <= One = case ixDepthExpr expr >= ixDepth of True -> if s <= One then UsedOnce else UsedMulti @@ -214,12 +168,8 @@ inlineDeclsSubst = \case -- since their main purpose is to force inlining in the simplifier, and if -- one just stuck like this it has become equivalent to a `for` anyway. ixDepthExpr :: Expr SimpIR n -> Int - ixDepthExpr (PrimOp (Hof (TypedHof _ (For _ _ (UnaryLamExpr _ body))))) = 1 + ixDepthBlock body + ixDepthExpr (PrimOp (Hof (TypedHof _ (For _ _ (UnaryLamExpr _ body))))) = 1 + ixDepthExpr body ixDepthExpr _ = 0 - ixDepthBlock :: Block SimpIR n -> Int - ixDepthBlock (exprBlock -> (Just expr)) = ixDepthExpr expr - ixDepthBlock (Abs Empty result) = ixDepthExpr $ Atom result - ixDepthBlock _ = 0 -- Should we decide to inline this binding wherever it appears, before we even -- know the expression? "Yes" only if we know it only occurs once, and in a @@ -227,7 +177,9 @@ inlineDeclsSubst = \case preInlineUnconditionally :: LetAnn -> Bool preInlineUnconditionally = \case PlainLet -> False -- "Missing occurrence annotation" + InlineLet -> True NoInlineLet -> False + LinearLet -> False OccInfoPure (UsageInfo s (0, d)) | s <= One && d <= One -> True OccInfoPure _ -> False OccInfoImpure _ -> False @@ -245,7 +197,7 @@ preInlineUnconditionally = \case -- instead of emitting the binding. data Context (from::E) (to::E) (o::S) where Stop :: Context e e o - TabAppCtx :: [SAtom i] -> Subst InlineSubstVal i o + TabAppCtx :: SAtom i -> Subst InlineSubstVal i o -> Context SExpr e o -> Context SExpr e o CaseCtx :: [SAlt i] -> SType i -> EffectRow SimpIR i -> Subst InlineSubstVal i o @@ -270,22 +222,35 @@ instance Emits o => Visitor (InlineM i o) SimpIR i o where inlineExpr :: Emits o => Context SExpr e o -> SExpr i -> InlineM i o (e o) inlineExpr ctx = \case Atom atom -> inlineAtom ctx atom - TabApp _ tbl ixs -> do + TabApp _ tbl ix -> do s <- getSubst - inlineAtom (TabAppCtx ixs s ctx) tbl + inlineAtom (TabAppCtx ix s ctx) tbl Case scrut alts (EffTy effs resultTy) -> do s <- getSubst inlineAtom (CaseCtx alts resultTy effs s ctx) scrut + Block _ (Abs decls ans) -> do + s <- inlineDeclsSubst decls + withSubst s $ inlineExpr ctx ans expr -> visitGeneric expr >>= reconstruct ctx inlineAtom :: Emits o => Context SExpr e o -> SAtom i -> InlineM i o (e o) inlineAtom ctx = \case + Stuck _ stuck -> inlineStuck ctx stuck + Con con -> (toExpr <$> visitGeneric con) >>= reconstruct ctx + +inlineStuck :: Emits o => Context SExpr e o -> SStuck i -> InlineM i o (e o) +inlineStuck ctx = \case Var name -> inlineName ctx name - ProjectElt _ i x -> do - let (idxs, v) = asNaryProj i x - ans <- normalizeNaryProj (NE.toList idxs) =<< inline Stop (Var v) + StuckProject i x -> do + ans <- proj i =<< emit =<< inlineStuck Stop x reconstruct ctx $ Atom ans - atom -> (Atom <$> visitAtomPartial atom) >>= reconstruct ctx + StuckTabApp _ _ -> error "not implemented" + PtrVar t p -> do + s <- mkStuck =<< (PtrVar t <$> substM p) + reconstruct ctx (toExpr s) + RepValAtom repVal -> do + s <- mkStuck =<< (RepValAtom <$> visitGeneric repVal) + reconstruct ctx (toExpr s) inlineName :: Emits o => Context SExpr e o -> SAtomVar i -> InlineM i o (e o) inlineName ctx name = @@ -300,7 +265,7 @@ inlineName ctx name = -- (expr', presInfo) | inline presInfo expr' ctx -> inline -- no info -> do not inline (as now) v <- toAtomVar name' - reconstruct ctx (Atom $ Var v) + reconstruct ctx (toExpr v) SubstVal (DoneEx expr) -> dropSubst $ inlineExpr ctx expr SubstVal (SuspEx expr s') -> withSubst s' $ inlineExpr ctx expr @@ -311,13 +276,13 @@ instance Inlinable SAtom where inline ctx a = inlineAtom (EmitToAtomCtx ctx) a instance Inlinable SType where - inline ctx ty = visitTypePartial ty >>= reconstruct ctx + inline ctx (TyCon ty) = (TyCon <$> visitGeneric ty) >>= reconstruct ctx instance Inlinable SLam where - inline ctx (LamExpr bs (Abs decls ans)) = do + inline ctx (LamExpr bs body) = do reconstruct ctx =<< withBinders bs \bs' -> do - (LamExpr bs' <$>) $ buildScoped $ - inlineDecls decls $ inline Stop ans + body' <- buildBlock $ inlineExpr Stop body + return $ LamExpr bs' body' withBinders :: Nest SBinder i i' @@ -336,76 +301,30 @@ instance Inlinable (PiType SimpIR) where effTy' <- buildScopedAssumeNoDecls $ inline Stop effTy return $ PiType bs' effTy' -inlineBlockEmits :: Emits o => Context SExpr e2 o -> SBlock i -> InlineM i o (e2 o) -inlineBlockEmits ctx (Abs decls ans) = do - inlineDecls decls $ inlineAtom ctx ans - -- Still using InlineM because we may call back into inlining, and we wish to -- retain our output binding environment. reconstruct :: Emits o => Context e1 e2 o -> e1 o -> InlineM i o (e2 o) reconstruct ctx e = case ctx of Stop -> return e - TabAppCtx ixs s ctx' -> withSubst s $ reconstructTabApp ctx' e ixs + TabAppCtx ix s ctx' -> withSubst s $ reconstructTabApp ctx' e ix CaseCtx alts resultTy effs s ctx' -> withSubst s $ reconstructCase ctx' e alts resultTy effs - EmitToAtomCtx ctx' -> emitExprToAtom e >>= reconstruct ctx' - EmitToNameCtx ctx' -> emit (Atom e) >>= reconstruct ctx' + EmitToAtomCtx ctx' -> emit e >>= reconstruct ctx' + EmitToNameCtx ctx' -> emitToVar e >>= reconstruct ctx' {-# INLINE reconstruct #-} reconstructTabApp :: Emits o - => Context SExpr e o -> SExpr o -> [SAtom i] -> InlineM i o (e o) -reconstructTabApp ctx expr [] = do - reconstruct ctx expr -reconstructTabApp ctx expr ixs = - case fromNaryForExpr (length ixs) expr of - Just (bsCount, LamExpr bs (Abs decls result)) -> do - let (ixsPref, ixsRest) = splitAt bsCount ixs - -- Note: There's a decision here. Is it ok to inline the atoms in - -- `ixsPref` into the body `decls`? If so, should we pre-process them and - -- carry them in `DoneEx`, or suspend them in `SuspEx`? (If not, we can - -- emit fresh bindings and use `Rename`.) We can't make this decision - -- properly without annotating the `for` binders with occurrence - -- information; even though `ixsPref` itself are atoms, we may be carrying - -- suspended inlining decisions that would want to make one an expression, - -- and thus force-inlining it may duplicate work. - -- - -- There remains a decision between just emitting bindings, or running - -- `mapM (inline $ EmitToAtomCtx Stop)` and inlining the resulting atoms. - -- In the work-heavy case where an element of `ixsPref` becomes an - -- expression after inlining, the result will be the same; but in the - -- work-light case where the element remains an atom, more inlining can - -- proceed. This decision only affects the runtime of the inliner and the - -- code size of the IR the inliner produces. - -- - -- Current status: Emitting bindings in the interest if "launch and - -- iterate"; have not tried `EmitToAtomCtx`. - ixsPref' <- mapM (inline $ EmitToNameCtx Stop) ixsPref - let ixsPref'' = [v | AtomVar v _ <- ixsPref'] - s <- getSubst - let moreSubst = bs @@> map Rename ixsPref'' - dropSubst $ extendSubst moreSubst do - -- Decision here. These decls have already been processed by the - -- inliner once, so their occurrence information is stale (and should - -- have been erased). Do we rerun occurrence analysis, or just complete - -- the pass without inlining any of them? - -- - Con rerunning: Slower - -- - Con completing: No detection of erroneous lack of occurrence info - -- For now went with "completing"; to detect erroneous lack of - -- occurrence info, change the relevant PlainLet cases above. - -- - -- There's also a missed opportunity here to do more inlining in one - -- pass: we lost the occurrence information of the bindings, so we lost - -- the ability to inline them into the result, so in the common case - -- that the result is a variable reference, we will find ourselves - -- emitting a rename, _which will inhibit downstream inlining_ because a - -- rename is not indexable. - inlineDecls decls do - let ctx' = TabAppCtx ixsRest s ctx - inlineAtom ctx' result - Nothing -> do - array' <- emitExprToAtom expr - ixs' <- mapM (inline Stop) ixs - reconstruct ctx =<< mkTabApp array' ixs' + => Context SExpr e o -> SExpr o -> SAtom i -> InlineM i o (e o) +reconstructTabApp ctx expr i = case expr of + PrimOp (Hof (TypedHof _ (For _ _ (UnaryLamExpr b body)))) -> do + -- See NoteReconstructTabAppDecisions + AtomVar i' _ <- inline (EmitToNameCtx Stop) i + dropSubst $ extendSubst (b@>Rename i') do + inlineExpr ctx body + _ -> do + array' <- emit expr + i' <- inline Stop i + reconstruct ctx =<< mkTabApp array' i' reconstructCase :: Emits o => Context SExpr e o -> SExpr o -> [SAlt i] -> SType i -> EffectRow SimpIR i @@ -418,23 +337,24 @@ reconstructCase ctx scrutExpr alts resultTy effs = -- of the arms of the outer case resultTy' <- inline Stop resultTy reconstruct ctx =<< (buildCase' sscrut resultTy' \i val -> do - ans <- applyAbs (sink $ salts !! i) (SubstVal val) >>= emitBlock + ans <- applyAbs (sink $ salts !! i) (SubstVal val) >>= emit buildCase ans (sink resultTy') \j jval -> do Abs b body <- return $ alts !! j extendSubst (b @> (SubstVal $ DoneEx $ Atom jval)) do - inlineBlockEmits Stop body >>= emitExprToAtom) + inlineExpr Stop body >>= emit) _ -> do -- Attempt case-of-known-constructor optimization -- I can't use `buildCase` here because I want to propagate the incoming -- context `ctx` into the selected alternative if the optimization fires, -- but leave it around the whole reconstructed `Case` if it doesn't. - scrut <- emitExprToAtom scrutExpr - case trySelectBranch scrut of - Just (i, val) -> do + scrut <- emit scrutExpr + case scrut of + Con con -> do + SumCon _ i val <- return con Abs b body <- return $ alts !! i extendSubst (b @> (SubstVal $ DoneEx $ Atom val)) do - inlineBlockEmits ctx body - Nothing -> do + inlineExpr ctx body + Stuck _ _ -> do alts' <- mapM visitAlt alts resultTy' <- inline Stop resultTy effs' <- inline Stop effs @@ -442,3 +362,84 @@ reconstructCase ctx scrutExpr alts resultTy effs = instance Inlinable (EffectRow SimpIR) instance Inlinable (EffTy SimpIR) + +-- === NoteReconstructTabAppDecisions === + +-- There's a decision here. Is it ok to inline the atoms in `ixsPref` into the +-- body `decls`? If so, should we pre-process them and carry them in `DoneEx`, +-- or suspend them in `SuspEx`? (If not, we can emit fresh bindings and use +-- `Rename`.) We can't make this decision properly without annotating the `for` +-- binders with occurrence information; even though `ixsPref` itself are atoms, +-- we may be carrying suspended inlining decisions that would want to make one +-- an expression, and thus force-inlining it may duplicate work. +-- +-- There remains a decision between just emitting bindings, or running `mapM +-- (inline $ EmitToAtomCtx Stop)` and inlining the resulting atoms. In the +-- work-heavy case where an element of `ixsPref` becomes an expression after +-- inlining, the result will be the same; but in the work-light case where the +-- element remains an atom, more inlining can proceed. This decision only +-- affects the runtime of the inliner and the code size of the IR the inliner +-- produces. +-- +-- Current status: Emitting bindings in the interest if "launch and iterate"; +-- have not tried `EmitToAtomCtx`. Decision here. These decls have already been +-- processed by the inliner once, so their occurrence information is stale (and +-- should have been erased). Do we rerun occurrence analysis, or just complete +-- the pass without inlining any of them? +-- - Con rerunning: Slower +-- - Con completing: No detection of erroneous lack of occurrence info +-- For now went with "completing"; to detect erroneous lack of +-- occurrence info, change the relevant PlainLet cases above. +-- +-- There's also a missed opportunity here to do more inlining in one pass: we +-- lost the occurrence information of the bindings, so we lost the ability to +-- inline them into the result, so in the common case that the result is a +-- variable reference, we will find ourselves emitting a rename, _which will +-- inhibit downstream inlining_ because a rename is not indexable. + +-- === NoteConditionalInlining === + +-- TODO For now, this inliner does not do any conditional inlining. In order to +-- do it, we would need to augment the environment at this point, associating +-- name' to (expr', presInfo) so name' could be inlined at use sites. +-- +-- Conditional inlining is different in Dex vs Haskell because Dex is strict. To +-- wit, once we have emitted the bidning for `expr'`, we are committed to doing +-- the work it represents unless it's inlined _everywhere_. For example, +-- xs = +-- case of +-- Nothing -> xs -- ok to inline here +-- Just _ -> xs ... xs -- not ok here +-- If this were Haskell, it would be work-preserving for GHC to inline +-- `xs` into the `Nothing` arm, but in Dex it's not, unless we first +-- explicitly push the binding into the case like +-- case of +-- Nothing -> xs = ; xs +-- Just _ -> xs = ; xs ... xs +-- +-- That said, the Secrets paper says that GHC only conditionally inlines +-- zero-work bindings anyway (or, more precisely, "bounded finite work" +-- bindings). All the heuristics about whether to inline at a particular site +-- are about code size and not increasing it overmuch. But, of course, inlining +-- even zero-work bindings can help runtime performance because it can unblock +-- other optimizations that otherwise could not occur across the binding. + +-- === NoteSecretsSubtlety === + +-- A subtlety from the Secrets paper. In Haskell, it is feasible to have a +-- binding whose occurrence information indicates multiple uses, but which does +-- a small, bounded amount of runtime work. GHC will inline such a binding, but +-- not into contexts where GHC knows that no further optimizations are possible. +-- The example given in the paper is +-- f = \x -> E +-- g = \ys -> map f ys +-- Inlining f here is useless because it's not applied, and mildly costly +-- because it causes the closure to be allocated at every call to g rather than +-- just once. +-- TODO If we want to track this subtlety, we should make room for it in +-- the SizePreservationInfo ADT (maybe rename it), maybe with a +-- OnceButDuplicatesBoundedWork constructor. Then only the true UsedOnce +-- would be inlined unconditionally here, and the +-- OnceButDuplicatesBoundedWork constructor could be inlined or not +-- depending on its usage context. (This would correspond to the case +-- OnceUnsafe with whnfOrBot == True in the Secrets paper.) diff --git a/src/lib/JAX/ToSimp.hs b/src/lib/JAX/ToSimp.hs index e2e183955..7466d237b 100644 --- a/src/lib/JAX/ToSimp.hs +++ b/src/lib/JAX/ToSimp.hs @@ -17,6 +17,7 @@ import JAX.Concrete import Subst import QueryType import Types.Core +import Types.Top import Types.Primitives qualified as P newtype JaxSimpM (i::S) (o::S) a = JaxSimpM @@ -66,10 +67,10 @@ simplifyJTy JArrayName{shape, dtype} = go shape $ simplifyDType dtype where simplifyDType :: DType -> Type r n simplifyDType = \case - F64 -> BaseTy $ P.Scalar P.Float64Type - F32 -> BaseTy $ P.Scalar P.Float32Type - I64 -> BaseTy $ P.Scalar P.Int64Type - I32 -> BaseTy $ P.Scalar P.Int32Type + F64 -> TyCon $ BaseType $ P.Scalar P.Float64Type + F32 -> TyCon $ BaseType $ P.Scalar P.Float32Type + I64 -> TyCon $ BaseType $ P.Scalar P.Int64Type + I32 -> TyCon $ BaseType $ P.Scalar P.Int32Type simplifyEqns :: Emits o => Nest JEqn i i' -> JaxSimpM i' o a -> JaxSimpM i o a simplifyEqns eqn cont = do @@ -104,7 +105,7 @@ simplifyAtom = \case SubstVal x -> return (x, ty) Rename nm' -> do nm'' <- toAtomVar nm' - return (Var nm'', ty) + return (toAtom nm'', ty) -- TODO In Jax, literals can presumably include (large) arrays. How should we -- represent them here? JLiteral (JLit {..}) -> return (Con (Lit (P.Float32Lit 0.0)), ty) @@ -122,7 +123,7 @@ unaryExpandRank :: forall i o. Emits o unaryExpandRank op arg JArrayName{shape} = go arg shape where go :: Emits l => SAtom l -> [DimSizeName] -> JaxSimpM i l (SAtom l) go arg' = \case - [] -> emitExprToAtom $ PrimOp (UnOp op arg') + [] -> emit $ PrimOp (UnOp op arg') (DimSize sz:rest) -> buildFor noHint P.Fwd (litFinIxTy sz) \i -> do - ixed <- mkTabApp (sink arg') [Var i] >>= emitExprToAtom + ixed <- mkTabApp (sink arg') (toAtom i) >>= emit go ixed rest diff --git a/src/lib/LLVM/CUDA.hs b/src/lib/LLVM/CUDA.hs index f0a6edca8..646fe59ec 100644 --- a/src/lib/LLVM/CUDA.hs +++ b/src/lib/LLVM/CUDA.hs @@ -40,7 +40,7 @@ import qualified Data.Set as S import LLVM.Compile import Types.Imp -import Types.Misc +import Types.Source data LLVMKernel = LLVMKernel L.Module diff --git a/src/lib/LLVM/Compile.hs b/src/lib/LLVM/Compile.hs index c3d690ad9..b4248bbf3 100644 --- a/src/lib/LLVM/Compile.hs +++ b/src/lib/LLVM/Compile.hs @@ -30,15 +30,10 @@ import System.IO.Unsafe import Control.Monad -import Logging import PPrint () import Paths_dex (getDataFileName) -import Types.Misc --- The only reason this module depends on Types.Source is that we pass in the logger, --- in order to optionally print out the IRs. LLVM mutates its IRs in-place, so --- we can't just expose a functional API for each stage without taking a --- performance hit. But maybe the performance hit isn't so bad? import Types.Source +import MonadUtil data LLVMOptLevel = OptALittle -- -O1 @@ -110,7 +105,11 @@ standardCompilationPipeline opt logger exports tm m = do {-# SCC showAssembly #-} logPass AsmPass $ showAsm tm m where logPass :: PassName -> IO String -> IO () - logPass passName cont = logFiltered logger passName $ cont >>= \s -> return [PassInfo passName s] + logPass passName showIt = case ioLogLevel logger of + DebugLogLevel -> do + s <- showIt + ioLogAction logger $ Outputs [PassInfo passName s] + NormalLogLevel -> return () {-# SCC standardCompilationPipeline #-} internalize :: [String] -> Mod.Module -> IO () diff --git a/src/lib/Lexing.hs b/src/lib/Lexing.hs index 0f0fc3ddb..ec916f749 100644 --- a/src/lib/Lexing.hs +++ b/src/lib/Lexing.hs @@ -15,6 +15,7 @@ import Data.Text (Text) import Data.Text qualified as T import Data.Void import Data.Word +import qualified Data.Map.Strict as M import Text.Megaparsec hiding (Label, State) import Text.Megaparsec.Char hiding (space, eol) @@ -23,23 +24,27 @@ import qualified Text.Megaparsec.Char.Lexer as L import Text.Megaparsec.Debug import Err -import SourceInfo +import PPrint import Types.Primitives +import Types.Source +import Util (toSnocList) data ParseCtx = ParseCtx { curIndent :: Int -- used Reader-style (i.e. ask/local) , canBreak :: Bool -- used Reader-style (i.e. ask/local) , prevWhitespace :: Bool -- tracks whether we just consumed whitespace - } + , sourceIdCounter :: Int -- starts at 1 (0 is reserved for the root) + , curAtomicLexemes :: [SrcId] + , curLexemeInfo :: LexemeInfo } -- append to, writer-style initParseCtx :: ParseCtx -initParseCtx = ParseCtx 0 False False +initParseCtx = ParseCtx 0 False False 1 mempty mempty type Parser = StateT ParseCtx (Parsec Void Text) parseit :: Text -> Parser a -> Except a parseit s p = case parse (fst <$> runStateT p initParseCtx) "" s of - Left e -> throw ParseErr $ errorBundlePretty e + Left e -> throwErr $ ParseErr $ MiscParseErr $ errorBundlePretty e Right x -> return x mustParseit :: Text -> Parser a -> a @@ -63,12 +68,15 @@ nextChar = do return $ T.head i {-# INLINE nextChar #-} -anyCaseName :: Lexer SourceName -anyCaseName = label "name" $ lexeme $ - checkNotKeyword $ (:) <$> satisfy (\c -> isLower c || isUpper c) <*> +anyCaseName :: Lexer (WithSrc SourceName) +anyCaseName = label "name" $ lexeme LowerName anyCaseName' -- TODO: distinguish lowercase/uppercase + +anyCaseName' :: Lexer SourceName +anyCaseName' = + liftM MkSourceName $ checkNotKeyword $ (:) <$> satisfy (\c -> isLower c || isUpper c) <*> (T.unpack <$> takeWhileP Nothing (\c -> isAlphaNum c || c == '\'' || c == '_')) -anyName :: Lexer SourceName +anyName :: Lexer (WithSrc SourceName) anyName = anyCaseName <|> symName checkNotKeyword :: Parser String -> Parser String @@ -121,8 +129,11 @@ keyWordToken = \case PassKW -> "pass" keyWord :: KeyWord -> Lexer () -keyWord kw = lexeme $ try $ string (fromString $ keyWordToken kw) - >> notFollowedBy nameTailChar +keyWord kw = atomicLexeme Keyword $ try $ + string (fromString $ keyWordToken kw) >> notFollowedBy nameTailChar + where + nameTailChar :: Parser Char + nameTailChar = alphaNumChar <|> char '\'' <|> char '_' keyWordSet :: HS.HashSet String keyWordSet = HS.fromList keyWordStrs @@ -130,20 +141,20 @@ keyWordSet = HS.fromList keyWordStrs keyWordStrs :: [String] keyWordStrs = map keyWordToken [DefKW .. PassKW] -primName :: Lexer String -primName = lexeme $ try $ char '%' >> some alphaNumChar +primName :: Lexer (WithSrc String) +primName = lexeme MiscLexeme $ try $ char '%' >> some alphaNumChar -charLit :: Lexer Char -charLit = lexeme $ char '\'' >> L.charLiteral <* char '\'' +charLit :: Lexer (WithSrc Char) +charLit = lexeme MiscLexeme $ char '\'' >> L.charLiteral <* char '\'' -strLit :: Lexer String -strLit = lexeme $ char '"' >> manyTill L.charLiteral (char '"') +strLit :: Lexer (WithSrc String) +strLit = lexeme StringLiteralLexeme $ char '"' >> manyTill L.charLiteral (char '"') -natLit :: Lexer Word64 -natLit = lexeme $ try $ L.decimal <* notFollowedBy (char '.') +natLit :: Lexer (WithSrc Word64) +natLit = lexeme LiteralLexeme $ try $ L.decimal <* notFollowedBy (char '.') -doubleLit :: Lexer Double -doubleLit = lexeme $ +doubleLit :: Lexer (WithSrc Double) +doubleLit = lexeme LiteralLexeme $ try L.float <|> try (fromIntegral <$> (L.decimal :: Parser Int) <* char '.') <|> try do @@ -156,27 +167,33 @@ knownSymStrs :: HS.HashSet String knownSymStrs = HS.fromList [ ".", ":", "::", "!", "=", "-", "+", "||", "&&" , "$", "&>", "|", ",", ",>", "<-", "+=", ":=" - , "->", "->>", "=>", "?->", "?=>", "--o", "--", "<<<", ">>>" + , "->", "->>", "=>", "?->", "?=>", "<<<", ">>>" , "..", "<..", "..<", "..<", "<..<", "?", "#", "##", "#?", "#&", "#|", "@"] --- string must be in `knownSymStrs` sym :: Text -> Lexer () -sym s = lexeme $ try $ string s >> notFollowedBy symChar +sym s = atomicLexeme Symbol $ sym' s + +symWithId :: Text -> Lexer SrcId +symWithId s = liftM srcPos $ lexeme Symbol $ sym' s + +-- string must be in `knownSymStrs` +sym' :: Text -> Lexer () +sym' s = void $ try $ string s >> notFollowedBy symChar -anySym :: Lexer String -anySym = lexeme $ try $ do +anySym :: Lexer (WithSrc String) +anySym = lexeme Symbol $ try $ do s <- some symChar failIf (s `HS.member` knownSymStrs) "" return s -symName :: Lexer SourceName -symName = label "symbol name" $ lexeme $ try $ do +symName :: Lexer (WithSrc SourceName) +symName = label "symbol name" $ lexeme Symbol $ try $ do s <- between (char '(') (char ')') $ some symChar - return $ "(" <> s <> ")" + return $ MkSourceName $ "(" <> s <> ")" -backquoteName :: Lexer SourceName +backquoteName :: Lexer (WithSrc SourceName) backquoteName = label "backquoted name" $ - lexeme $ try $ between (char '`') (char '`') anyCaseName + lexeme Symbol $ try $ between (char '`') (char '`') anyCaseName' -- brackets and punctuation -- (can't treat as sym because e.g. `((` is two separate lexemes) @@ -192,10 +209,7 @@ semicolon = charLexeme ';' underscore = charLexeme '_' charLexeme :: Char -> Parser () -charLexeme c = void $ lexeme $ char c - -nameTailChar :: Parser Char -nameTailChar = alphaNumChar <|> char '\'' <|> char '_' +charLexeme c = atomicLexeme Symbol $ void $ char c symChar :: Parser Char symChar = token (\c -> if HS.member c symChars then Just c else Nothing) mempty @@ -203,6 +217,10 @@ symChar = token (\c -> if HS.member c symChars then Just c else Nothing) mempty symChars :: HS.HashSet Char symChars = HS.fromList ".,!$^&*:-~+/=<>|?\\@#" +-- XXX: unlike other lexemes, this doesn't consume trailing whitespace +dot :: Parser SrcId +dot = srcPos <$> lexeme' (return ()) Symbol (void $ char '.') + -- === Util === sc :: Parser () @@ -210,9 +228,7 @@ sc = (skipSome s >> recordWhitespace) <|> return () where s = hidden space <|> hidden lineComment lineComment :: Parser () -lineComment = do - try $ string "--" >> notFollowedBy (void (char 'o')) - void (takeWhileP (Just "char") (/= '\n')) +lineComment = string "#" >> void (takeWhileP (Just "char") (/= '\n')) outputLines :: Parser () outputLines = void $ many (symbol ">" >> takeWhileP Nothing (/= '\n') >> ((eol >> return ()) <|> eof)) @@ -222,12 +238,21 @@ space = gets canBreak >>= \case True -> space1 False -> void $ takeWhile1P (Just "white space") (`elem` (" \t" :: String)) +setCanBreakLocally :: Bool -> Parser a -> Parser a +setCanBreakLocally brLocal p = do + brPrev <- gets canBreak + modify \ctx -> ctx {canBreak = brLocal} + ans <- p + modify \ctx -> ctx {canBreak = brPrev} + return ans +{-# INLINE setCanBreakLocally #-} + mayBreak :: Parser a -> Parser a -mayBreak p = pLocal (\ctx -> ctx { canBreak = True }) p +mayBreak p = setCanBreakLocally True p {-# INLINE mayBreak #-} mayNotBreak :: Parser a -> Parser a -mayNotBreak p = pLocal (\ctx -> ctx { canBreak = False }) p +mayNotBreak p = setCanBreakLocally False p {-# INLINE mayNotBreak #-} precededByWhitespace :: Parser Bool @@ -243,35 +268,23 @@ recordNonWhitespace = modify \ctx -> ctx { prevWhitespace = False } {-# INLINE recordNonWhitespace #-} nameString :: Parser String -nameString = lexeme . try $ (:) <$> lowerChar <*> many alphaNumChar +nameString = lexemeIgnoreSrcId LowerName . try $ (:) <$> lowerChar <*> many alphaNumChar thisNameString :: Text -> Parser () -thisNameString s = lexeme $ try $ string s >> notFollowedBy alphaNumChar +thisNameString s = lexemeIgnoreSrcId MiscLexeme $ try $ string s >> notFollowedBy alphaNumChar bracketed :: Parser () -> Parser () -> Parser a -> Parser a -bracketed left right p = between left right $ mayBreak $ sc >> p +bracketed left right p = do + left + ans <- mayBreak $ sc >> p + right + return ans {-# INLINE bracketed #-} -parens :: Parser a -> Parser a -parens p = bracketed lParen rParen p -{-# INLINE parens #-} - -brackets :: Parser a -> Parser a -brackets p = bracketed lBracket rBracket p -{-# INLINE brackets #-} - braces :: Parser a -> Parser a braces p = bracketed lBrace rBrace p {-# INLINE braces #-} -withPos :: Parser a -> Parser (a, SrcPos) -withPos p = do - n <- getOffset - x <- p - n' <- getOffset - return $ (x, (n, n')) -{-# INLINE withPos #-} - nextLine :: Parser () nextLine = do eol @@ -282,7 +295,9 @@ nextLine = do withSource :: Parser a -> Parser (Text, a) withSource p = do s <- getInput - (x, (start, end)) <- withPos p + start <- getOffset + x <- p + end <- getOffset return (T.take (end - start) s, x) {-# INLINE withSource #-} @@ -291,14 +306,16 @@ withIndent p = do nextLine indent <- T.length <$> takeWhileP (Just "space") (==' ') when (indent <= 0) empty - pLocal (\ctx -> ctx { curIndent = curIndent ctx + indent }) $ mayNotBreak p + locallyExtendCurIndent indent $ mayNotBreak p {-# INLINE withIndent #-} -pLocal :: (ParseCtx -> ParseCtx) -> Parser a -> Parser a -pLocal f p = do - s <- get - put (f s) >> p <* put s -{-# INLINE pLocal #-} +locallyExtendCurIndent :: Int -> Parser a -> Parser a +locallyExtendCurIndent n p = do + indentPrev <- gets curIndent + modify \ctx -> ctx { curIndent = indentPrev + n } + ans <- p + modify \ctx -> ctx { curIndent = indentPrev } + return ans eol :: Parser () eol = void MC.eol @@ -310,10 +327,59 @@ failIf :: Bool -> String -> Parser () failIf True s = fail s failIf False _ = return () -lexeme :: Parser a -> Parser a -lexeme p = L.lexeme sc (p <* recordNonWhitespace) -{-# INLINE lexeme #-} +freshSrcId :: Parser SrcId +freshSrcId = do + c <- gets sourceIdCounter + modify \ctx -> ctx { sourceIdCounter = c + 1 } + return $ SrcId c + +withLexemeInfo :: Parser a -> Parser (LexemeInfo, a) +withLexemeInfo cont = do + smPrev <- gets curLexemeInfo + modify \ctx -> ctx { curLexemeInfo = mempty } + result <- cont + sm <- gets curLexemeInfo + modify \ctx -> ctx { curLexemeInfo = smPrev } + return (sm, result) + +emitLexemeInfo :: LexemeInfo -> Parser () +emitLexemeInfo m = modify \ctx -> ctx { curLexemeInfo = curLexemeInfo ctx <> m } + +lexemeIgnoreSrcId :: LexemeType -> Parser a -> Parser a +lexemeIgnoreSrcId lexemeType p = withoutSrc <$> lexeme lexemeType p symbol :: Text -> Parser () symbol s = void $ L.symbol sc s +lexeme :: LexemeType -> Parser a -> Parser (WithSrc a) +lexeme lexemeType p = lexeme' sc lexemeType p +{-# INLINE lexeme #-} + +lexeme' :: Parser () -> LexemeType -> Parser a -> Parser (WithSrc a) +lexeme' sc' lexemeType p = do + start <- getOffset + ans <- p + end <- getOffset + recordNonWhitespace + sc' + sid <- freshSrcId + emitLexemeInfo $ mempty + { lexemeList = toSnocList [sid] + , lexemeInfo = M.singleton sid (lexemeType, (start, end)) } + return $ WithSrc sid ans +{-# INLINE lexeme' #-} + +atomicLexeme :: LexemeType -> Parser () -> Parser () +atomicLexeme lexemeType p = do + WithSrc sid () <- lexeme lexemeType p + modify \ctx -> ctx { curAtomicLexemes = curAtomicLexemes ctx ++ [sid] } +{-# INLINE atomicLexeme #-} + +collectAtomicLexemeIds :: Parser a -> Parser ([SrcId], a) +collectAtomicLexemeIds p = do + prevAtomicLexemes <- gets curAtomicLexemes + modify \ctx -> ctx { curAtomicLexemes = [] } + ans <- p + localLexemes <- gets curAtomicLexemes + modify \ctx -> ctx { curAtomicLexemes = prevAtomicLexemes } + return (localLexemes, ans) diff --git a/src/lib/Linearize.hs b/src/lib/Linearize.hs index d32b5230a..ee61d8437 100644 --- a/src/lib/Linearize.hs +++ b/src/lib/Linearize.hs @@ -8,7 +8,6 @@ module Linearize (linearize, linearizeTopLam) where import Control.Category ((>>>)) import Control.Monad.Reader -import Data.Foldable (toList) import Data.Functor import Data.List (elemIndex) import Data.Maybe (catMaybes, isJust) @@ -28,7 +27,8 @@ import PPrint import QueryType import Types.Core import Types.Primitives -import Util (bindM2, enumerate) +import Types.Top +import Util (enumerate) -- === linearization monad === @@ -84,7 +84,7 @@ extendActivePrimalss vs = local \primals -> primals { activeVars = activeVars primals ++ vs } getTangentArg :: Int -> TangentM o (Atom SimpIR o) -getTangentArg idx = asks \(TangentArgs vs) -> Var $ vs !! idx +getTangentArg idx = asks \(TangentArgs vs) -> toAtom $ vs !! idx extendTangentArgs :: SAtomVar n -> TangentM n a -> TangentM n a extendTangentArgs v m = local (\(TangentArgs vs) -> TangentArgs $ vs ++ [v]) m @@ -95,48 +95,39 @@ extendTangentArgss vs' m = local (\(TangentArgs vs) -> TangentArgs $ vs ++ vs') getTangentArgs :: TangentM o (TangentArgs o) getTangentArgs = ask -bindLin - :: Emits o - => LinM i o e e - -> (forall o' m. (Emits o', DExt o o', Builder SimpIR m) => e o' -> m o' (e' o')) - -> LinM i o e' e' -bindLin m f = do - result <- m - withBoth result f - -withBoth - :: Emits o +emitBoth + :: (Emits o, ToExpr e' SimpIR) => WithTangent o e e - -> (forall o' m. (Emits o', DExt o o', Builder SimpIR m) => e o' -> m o' (e' o')) - -> PrimalM i o (WithTangent o e' e') -withBoth (WithTangent x tx) f = do + -> (forall o' m. (DExt o o', Builder SimpIR m) => e o' -> m o' (e' o')) + -> LinM i o SAtom SAtom +emitBoth (WithTangent x tx) f = do Distinct <- getDistinct - y <- f x - return $ WithTangent y do - tx >>= f + x' <- emit =<< f x + return $ WithTangent x' do + tx' <- tx + emitLin =<< f tx' -_withTangentComputation - :: Emits o - => WithTangent o e1 e2 - -> (forall o' m. (Emits o', DExt o o', Builder SimpIR m) => e2 o' -> m o' (e2' o')) - -> PrimalM i o (WithTangent o e1 e2') -_withTangentComputation (WithTangent x tx) f = do - Distinct <- getDistinct - return $ WithTangent x do - tx >>= f +emitZeroT :: (Emits o, HasNamesE e', ToExpr e' SimpIR) => e' i -> LinM i o SAtom SAtom +emitZeroT e = do + x <- emit =<< renameM e + return $ WithTangent x (zeroLikeT x) + +zeroLikeT :: (DExt o o', Emits o', HasType SimpIR e) => e o -> TangentM o' (SAtom o') +zeroLikeT x = do + ty <- sinkM $ getType x + zeroAt =<< tangentType ty fmapLin :: Emits o => (forall o'. e o' -> e' o') -> LinM i o e e -> LinM i o e' e' -fmapLin f m = m `bindLin` (pure . f) +fmapLin f m = do + WithTangent ans tx <- m + return $ WithTangent (f ans) (f <$> tx) -zipLin :: LinM i o e1 e1 -> LinM i o e2 e2 -> LinM i o (PairE e1 e2) (PairE e1 e2) -zipLin m1 m2 = do - WithTangent x1 t1 <- m1 - WithTangent x2 t2 <- m2 - return $ WithTangent (PairE x1 x2) do PairE <$> t1 <*> t2 +zipLin :: WithTangent o e1 e1 -> WithTangent o e2 e2 -> WithTangent o (PairE e1 e2) (PairE e1 e2) +zipLin (WithTangent x1 t1) (WithTangent x2 t2) = WithTangent (PairE x1 x2) do PairE <$> t1 <*> t2 seqLin :: Traversable f @@ -191,17 +182,17 @@ getTangentArgTys topVs = go mempty topVs where -- like this, but there's nothing to prevent users writing programs that -- sling around heap variables by themselves. We should try to do something -- better... - TC HeapType -> do - withFreshBinder (getNameHint v) (TC HeapType) \hb -> do + TyCon HeapType -> do + withFreshBinder (getNameHint v) (TyCon HeapType) \hb -> do let newHeapMap = sink heapMap <> eMapSingleton (sink (atomVarName v)) (binderVar hb) Abs bs UnitE <- go newHeapMap $ sinkList vs return $ EmptyAbs $ Nest hb bs - RefTy (Var h) referentTy -> do + RefTy (Stuck _ (Var h)) referentTy -> do case lookupEMap heapMap (atomVarName h) of Nothing -> error "shouldn't happen?" Just h' -> do tt <- tangentType referentTy - let refTy = RefTy (Var h') tt + let refTy = RefTy (toAtom h') tt withFreshBinder (getNameHint v) refTy \refb -> do Abs bs UnitE <- go (sink heapMap) $ sinkList vs return $ EmptyAbs $ Nest refb bs @@ -258,28 +249,29 @@ instance ReconFunctor ObligateReconAbs where linLam' <- applyReconAbs reconAbs residuals return (primal, linLam') -linearizeBlockDefunc :: SBlock i -> PrimalM i o (SBlock o, LinLamAbs o) -linearizeBlockDefunc = linearizeBlockDefuncGeneral emptyOutFrag +linearizeExprDefunc :: SExpr i -> PrimalM i o (SExpr o, LinLamAbs o) +linearizeExprDefunc = linearizeExprDefuncGeneral emptyOutFrag -linearizeBlockDefuncGeneral +linearizeExprDefuncGeneral :: ReconFunctor f - => ScopeFrag o' o -> SBlock i -> PrimalM i o (SBlock o, f SLam o') -linearizeBlockDefuncGeneral locals block = do + => ScopeFrag o' o -> SExpr i -> PrimalM i o (SExpr o, f SLam o') +linearizeExprDefuncGeneral locals expr = do Abs decls result <- buildScoped do - WithTangent primalResult tangentFun <- linearizeBlock block + WithTangent primalResult tangentFun <- linearizeExpr expr lam <- tangentFunAsLambda tangentFun return $ PairE primalResult lam - (block', recon) <- refreshAbs (Abs decls result) \decls' (PairE primal lam) -> do + (Abs decls' result', recon) <- refreshAbs (Abs decls result) \decls' (PairE primal lam) -> do (primal', recon) <- capture (locals >>> toScopeFrag decls') primal lam return (Abs decls' primal', recon) - return (block', recon) + block <- mkBlock (Abs decls' result') + return (block, recon) -- Inverse of tangentFunAsLambda. Should be used inside a returned tangent action. applyLinLam :: Emits o => SLam i -> SubstReaderT AtomSubstVal TangentM i o (Atom SimpIR o) applyLinLam (LamExpr bs body) = do TangentArgs args <- liftSubstReaderT $ getTangentArgs extendSubst (bs @@> ((Rename . atomVarName) <$> args)) do - substM body >>= emitBlock + substM body >>= emitLin -- === actual linearization passs === @@ -296,19 +288,19 @@ linearizeTopLam (TopLam False _ (LamExpr bs body)) actives = do True -> return $ Just v False -> return $ Nothing (body', linLamAbs) <- extendActivePrimalss activeVs do - linearizeBlockDefuncGeneral emptyOutFrag body + linearizeExprDefuncGeneral emptyOutFrag body let primalFun = LamExpr bs' body' ObligateRecon ty (Abs bsRecon (LamExpr bsTangent tangentBody)) <- return linLamAbs tangentFun <- withFreshBinder "residuals" ty \bResidual -> do - xs <- unpackTelescope bsRecon $ Var $ binderVar bResidual + xs <- unpackTelescope bsRecon $ toAtom $ binderVar bResidual Abs bsTangent' UnitE <- applySubst (bsRecon @@> map SubstVal xs) (Abs bsTangent UnitE) - tangentTy <- ProdTy <$> typesFromNonDepBinderNest bsTangent' + tangentTy <- TyCon <$> ProdType <$> typesFromNonDepBinderNest bsTangent' withFreshBinder "t" tangentTy \bTangent -> do tangentBody' <- buildBlock do - ts <- getUnpacked $ Var $ sink $ binderVar bTangent + ts <- getUnpacked $ toAtom $ sink $ binderVar bTangent let substFrag = bsRecon @@> map (SubstVal . sink) xs <.> bsTangent @@> map (SubstVal . sink) ts - emitBlock =<< applySubst substFrag tangentBody + emitLin =<< applySubst substFrag tangentBody return $ LamExpr (bs' >>> BinaryNest bResidual bTangent) tangentBody' return (primalFun, tangentFun) (,) <$> asTopLam primalFun <*> asTopLam tangentFun @@ -317,35 +309,37 @@ linearizeTopLam (TopLam True _ _) _ = error "expected a non-destination-passing -- reify the tangent builder as a lambda linearizeLambdaApp :: Emits o => SLam i -> SAtom o -> PrimalM i o (SAtom o, SLam o) linearizeLambdaApp (UnaryLamExpr b body) x = do - vp <- emit $ Atom x + vp <- emitToVar x extendActiveSubst b vp do - WithTangent primalResult tangentAction <- linearizeBlock body + WithTangent primalResult tangentAction <- linearizeExpr body tanFun <- tangentFunAsLambda tangentAction return (primalResult, tanFun) linearizeLambdaApp _ _ = error "not implemented" linearizeAtom :: Emits o => Atom SimpIR i -> LinM i o SAtom SAtom -linearizeAtom atom = case atom of +linearizeAtom (Con con) = linearizePrimCon con +linearizeAtom (Stuck _ stuck) = linearizeStuck stuck + +linearizeStuck :: Emits o => Stuck SimpIR i -> LinM i o SAtom SAtom +linearizeStuck stuck = case stuck of Var v -> do v' <- renameM v activePrimalIdx v' >>= \case - Nothing -> withZeroT $ return (Var v') - Just idx -> return $ WithTangent (Var v') $ getTangentArg idx - Con con -> linearizePrimCon con - DepPair _ _ _ -> notImplemented - PtrVar _ _ -> emitZeroT - ProjectElt _ i x -> do - WithTangent x' tx <- linearizeAtom x - xi <- normalizeProj i x' - return $ WithTangent xi do - t <- tx - normalizeProj i t - RepValAtom _ -> emitZeroT - where emitZeroT = withZeroT $ renameM atom - -linearizeBlock :: Emits o => SBlock i -> LinM i o SAtom SAtom -linearizeBlock (Abs decls result) = - linearizeDecls decls $ linearizeAtom result + Nothing -> zero + Just idx -> return $ WithTangent (toAtom v') $ getTangentArg idx + PtrVar _ _ -> zero + RepValAtom _ -> zero + -- TODO: de-dup with the Expr versions of these + StuckProject i x -> do + x' <- linearizeStuck x + emitBoth x' \x'' -> mkProject i x'' + StuckTabApp x i -> do + pt <- zipLin <$> linearizeStuck x <*> pureLin i + emitBoth pt \(PairE x' i') -> mkTabApp x' i' + where + zero = do + atom <- mkStuck =<< renameM stuck + return $ WithTangent atom (zeroLikeT atom) linearizeDecls :: Emits o => Nest SDecl i i' -> LinM i' o e1 e2 -> LinM i o e1 e2 linearizeDecls Empty cont = cont @@ -356,7 +350,7 @@ linearizeDecls (Nest (Let b (DeclBinding ann expr)) rest) cont = do expr' <- renameM expr isTrivialForAD expr' >>= \case True -> do - v <- emit expr' + v <- emitToVar expr' extendSubst (b@>atomVarName v) $ linearizeDecls rest cont False -> do WithTangent p tf <- linearizeExpr expr @@ -365,13 +359,14 @@ linearizeDecls (Nest (Let b (DeclBinding ann expr)) rest) cont = do WithTangent pRest tfRest <- linearizeDecls rest cont return $ WithTangent pRest do t <- tf - vt <- emitDecl (getNameHint b) ann (Atom t) + vt <- emitDecl (getNameHint b) LinearLet (Atom t) extendTangentArgs vt $ tfRest linearizeExpr :: Emits o => SExpr i -> LinM i o SAtom SAtom linearizeExpr expr = case expr of Atom x -> linearizeAtom x + Block _ (Abs decls result) -> linearizeDecls decls $ linearizeExpr result TopApp _ f xs -> do (xs', ts) <- unzip <$> forM xs \x -> do x' <- renameM x @@ -391,14 +386,11 @@ linearizeExpr expr = case expr of (ans, residuals) <- fromPair =<< naryTopApp fPrimal xs' return $ WithTangent ans do ts' <- forM (catMaybes ts) \(WithTangent UnitE t) -> t - naryTopApp (sink fTan) (sinkList xs' ++ [sink residuals, ProdVal ts']) + naryTopApp (sink fTan) (sinkList xs' ++ [sink residuals, Con $ ProdCon ts']) where unitLike :: e n -> UnitE n unitLike _ = UnitE - TabApp _ x idxs -> do - zipLin (linearizeAtom x) (pureLin $ ListE $ toList idxs) `bindLin` - \(PairE x' (ListE idxs')) -> naryTabApp x' idxs' - PrimOp op -> linearizeOp op + PrimOp op -> linearizeOp op Case e alts (EffTy effs resultTy) -> do e' <- renameM e effs' <- renameM effs @@ -409,60 +401,71 @@ linearizeExpr expr = case expr of (alts', recons) <- unzip <$> buildCaseAlts e' \i b' -> do Abs b body <- return $ alts !! i extendSubst (b@>binderName b') do - (block, recon) <- linearizeBlockDefuncGeneral (toScopeFrag b') body + (block, recon) <- linearizeExprDefuncGeneral (toScopeFrag b') body return (Abs b' block, recon) let tys = recons <&> \(ObligateRecon t _) -> t alts'' <- forM (enumerate alts') \(i, alt) -> do injectAltResult tys i alt - let fullResultTy = PairTy resultTy' $ SumTy tys - result <- emitExpr $ Case e' alts'' (EffTy effs' fullResultTy) + let fullResultTy = PairTy resultTy' $ TyCon $ SumType tys + result <- emit $ Case e' alts'' (EffTy effs' fullResultTy) (primal, residualss) <- fromPair result resultTangentType <- tangentType resultTy' return $ WithTangent primal do - buildCase (sink residualss) (sink resultTangentType) \i residuals -> do + emitLin =<< buildCase' (sink residualss) (sink resultTangentType) \i residuals -> do ObligateRecon _ (Abs bs linLam) <- return $ sinkList recons !! i residuals' <- unpackTelescope bs residuals withSubstReaderT $ extendSubst (bs @@> (SubstVal <$> residuals')) do applyLinLam linLam TabCon _ ty xs -> do ty' <- renameM ty - seqLin (map linearizeAtom xs) `bindLin` \(ComposeE xs') -> - emitExpr $ TabCon Nothing (sink ty') xs' + pt <- seqLin (map linearizeAtom xs) + emitBoth pt \(ComposeE xs') -> return $ TabCon Nothing (sink ty') xs' + TabApp _ x i -> do + pt <- zipLin <$> linearizeAtom x <*> pureLin i + emitBoth pt \(PairE x' i') -> mkTabApp x' i' + Project _ i x -> do + x' <- linearizeAtom x + emitBoth x' \x'' -> mkProject i x'' linearizeOp :: Emits o => PrimOp SimpIR i -> LinM i o SAtom SAtom linearizeOp op = case op of Hof (TypedHof _ e) -> linearizeHof e DAMOp _ -> error "shouldn't occur here" - RefOp ref m -> case m of - MAsk -> linearizeAtom ref `bindLin` \ref' -> liftM Var $ emit $ PrimOp $ RefOp ref' MAsk - MExtend monoid x -> do - -- TODO: check that we're dealing with a +/0 monoid - monoid' <- renameM monoid - zipLin (linearizeAtom ref) (linearizeAtom x) `bindLin` \(PairE ref' x') -> - liftM Var $ emit $ PrimOp $ RefOp ref' $ MExtend (sink monoid') x' - MGet -> linearizeAtom ref `bindLin` \ref' -> liftM Var $ emit $ PrimOp $ RefOp ref' MGet - MPut x -> zipLin (linearizeAtom ref) (linearizeAtom x) `bindLin` \(PairE ref' x') -> - liftM Var $ emit $ PrimOp $ RefOp ref' $ MPut x' - IndexRef _ i -> do - zipLin (la ref) (pureLin i) `bindLin` \(PairE ref' i') -> - emitOp =<< mkIndexRef ref' i' - ProjRef _ i -> la ref `bindLin` \ref' -> emitOp =<< mkProjRef ref' i + RefOp ref m -> do + ref' <- linearizeAtom ref + case m of + MAsk -> emitBoth ref' \ref'' -> return $ RefOp ref'' MAsk + MExtend monoid x -> do + -- TODO: check that we're dealing with a +/0 monoid + monoid' <- renameM monoid + x' <- linearizeAtom x + emitBoth (zipLin ref' x') \(PairE ref'' x'') -> + return $ RefOp ref'' $ MExtend (sink monoid') x'' + MGet -> emitBoth ref' \ref'' -> return $ RefOp ref'' MGet + MPut x -> do + x' <- linearizeAtom x + emitBoth (zipLin ref' x') \(PairE ref'' x'') -> return $ RefOp ref'' $ MPut x'' + IndexRef _ i -> do + i' <- pureLin i + emitBoth (zipLin ref' i') \(PairE ref'' i'') -> mkIndexRef ref'' i'' + ProjRef _ i -> emitBoth ref' \ref'' -> mkProjRef ref'' i UnOp uop x -> linearizeUnOp uop x BinOp bop x y -> linearizeBinOp bop x y -- XXX: This assumes that pointers are always constants - MemOp _ -> emitZeroT + MemOp _ -> emitZeroT op MiscOp miscOp -> linearizeMiscOp miscOp VectorOp _ -> error "not implemented" - where - emitZeroT = withZeroT $ liftM Var $ emit =<< renameM (PrimOp op) - la = linearizeAtom linearizeMiscOp :: Emits o => MiscOp SimpIR i -> LinM i o SAtom SAtom linearizeMiscOp op = case op of - SumTag _ -> emitZeroT - ToEnum _ _ -> emitZeroT - Select p t f -> (pureLin p `zipLin` la t `zipLin` la f) `bindLin` - \(p' `PairE` t' `PairE` f') -> emitOp $ MiscOp $ Select p' t' f' + SumTag _ -> zero + ToEnum _ _ -> zero + Select p t f -> do + p' <- pureLin p + t' <- linearizeAtom t + f' <- linearizeAtom f + emitBoth (p' `zipLin` t' `zipLin` f') + \(p'' `PairE` t'' `PairE` f'') -> return $ Select p'' t'' f'' CastOp t v -> do vt <- getType <$> renameM v t' <- renameM t @@ -471,92 +474,105 @@ linearizeMiscOp op = case op of ((&&) <$> (vtTangentType `alphaEq` vt) <*> (tTangentType `alphaEq` t')) >>= \case True -> do - linearizeAtom v `bindLin` \v' -> emitOp $ MiscOp $ CastOp (sink t') v' + v' <- linearizeAtom v + emitBoth v' \v'' -> return $ CastOp (sink t') v'' False -> do WithTangent x xt <- linearizeAtom v yt <- case (vtTangentType, tTangentType) of (_ , UnitTy) -> return $ UnitVal (UnitTy, tt ) -> zeroAt tt _ -> error "Expected at least one side of the CastOp to have a trivial tangent type" - y <- emitOp $ MiscOp $ CastOp t' x + y <- emit $ CastOp t' x return $ WithTangent y do xt >> return (sink yt) BitcastOp _ _ -> notImplemented UnsafeCoerce _ _ -> notImplemented GarbageVal _ -> notImplemented ThrowException _ -> notImplemented - ThrowError _ -> emitZeroT - OutputStream -> emitZeroT + ThrowError _ -> zero + OutputStream -> zero ShowAny _ -> error "Shouldn't have ShowAny in simplified IR" ShowScalar _ -> error "Shouldn't have ShowScalar in simplified IR" - where - emitZeroT = withZeroT $ liftM Var $ emit =<< renameM (PrimOp $ MiscOp op) - la = linearizeAtom + where zero = emitZeroT op linearizeUnOp :: Emits o => UnOp -> Atom SimpIR i -> LinM i o SAtom SAtom -linearizeUnOp op x' = do - WithTangent x tx <- linearizeAtom x' - let emitZeroT = withZeroT $ emitOp $ UnOp op x - case op of - Exp -> do - y <- emitUnOp Exp x - return $ WithTangent y (bindM2 mul tx (sinkM y)) - Exp2 -> notImplemented - Log -> withT (emitUnOp Log x) $ (tx >>= (`div'` sink x)) - Log2 -> notImplemented - Log10 -> notImplemented - Log1p -> notImplemented - Sin -> withT (emitUnOp Sin x) $ bindM2 mul tx (emitUnOp Cos (sink x)) - Cos -> withT (emitUnOp Cos x) $ bindM2 mul tx (neg =<< emitUnOp Sin (sink x)) - Tan -> notImplemented - Sqrt -> do - y <- emitUnOp Sqrt x - return $ WithTangent y do - denominator <- bindM2 mul (2 `fLitLike` sink y) (sinkM y) - bindM2 div' tx (pure denominator) - Floor -> emitZeroT - Ceil -> emitZeroT - Round -> emitZeroT - LGamma -> notImplemented - Erf -> notImplemented - Erfc -> notImplemented - FNeg -> withT (neg x) (neg =<< tx) - BNot -> emitZeroT +linearizeUnOp op x'' = do + WithTangent x' tx' <- linearizeAtom x'' + ans' <- emit $ UnOp op x' + return $ WithTangent ans' do + ans <- sinkM ans' + x <- sinkM x' + tx <- tx' + let zero = zeroLikeT ans + case op of + Exp -> emitLin $ BinOp FMul tx ans + Exp2 -> notImplemented + Log -> emitLin $ BinOp FDiv tx x + Log2 -> notImplemented + Log10 -> notImplemented + Log1p -> notImplemented + Sin -> do + c <- emit $ UnOp Cos x + emitLin $ BinOp FMul tx c + Cos -> do + c <- emit =<< (UnOp FNeg <$> emit (UnOp Sin x)) + emitLin $ BinOp FMul tx c + Tan -> notImplemented + Sqrt -> do + denominator <- fmul (2 `fLitLike` ans) ans + emitLin $ BinOp FDiv tx denominator + Floor -> zero + Ceil -> zero + Round -> zero + LGamma -> notImplemented + Erf -> notImplemented + Erfc -> notImplemented + FNeg -> emitLin $ UnOp FNeg tx + BNot -> zero linearizeBinOp :: Emits o => BinOp -> SAtom i -> SAtom i -> LinM i o SAtom SAtom -linearizeBinOp op x' y' = do - WithTangent x tx <- linearizeAtom x' - WithTangent y ty <- linearizeAtom y' - let emitZeroT = withZeroT $ emitOp $ BinOp op x y - case op of - IAdd -> emitZeroT - ISub -> emitZeroT - IMul -> emitZeroT - IDiv -> emitZeroT - IRem -> emitZeroT - ICmp _ -> emitZeroT - FAdd -> withT (add x y) (bindM2 add tx ty) - FSub -> withT (sub x y) (bindM2 sub tx ty) - FMul -> withT (mul x y) - (bindM2 add (bindM2 mul (referToPrimal x) ty) - (bindM2 mul tx (referToPrimal y))) - FDiv -> withT (div' x y) do - tx' <- bindM2 div' tx (referToPrimal y) - ty' <- bindM2 div' (bindM2 mul (referToPrimal x) ty) - (bindM2 mul (referToPrimal y) (referToPrimal y)) - sub tx' ty' - FPow -> withT (emitOp $ BinOp FPow x y) do - px <- referToPrimal x - py <- referToPrimal y - c <- (1.0 `fLitLike` py) >>= (sub py) >>= fpow px - tx' <- bindM2 mul tx (return py) - ty' <- bindM2 mul (bindM2 mul (return px) ty) (flog px) - mul c =<< add tx' ty' - FCmp _ -> emitZeroT - BAnd -> emitZeroT - BOr -> emitZeroT - BXor -> emitZeroT - BShL -> emitZeroT - BShR -> emitZeroT +linearizeBinOp op x'' y'' = do + WithTangent x' tx' <- linearizeAtom x'' + WithTangent y' ty' <- linearizeAtom y'' + ans' <- emit $ BinOp op x' y' + return $ WithTangent ans' do + ans <- sinkM ans' + x <- referToPrimal x' + y <- referToPrimal y' + tx <- tx' + ty <- ty' + let zero = zeroLikeT ans + case op of + IAdd -> zero + ISub -> zero + IMul -> zero + IDiv -> zero + IRem -> zero + ICmp _ -> zero + FAdd -> emitLin $ BinOp FAdd tx ty + FSub -> emitLin $ BinOp FSub tx ty + FMul -> do + t1 <- emitLin $ BinOp FMul ty x + t2 <- emitLin $ BinOp FMul tx y + emitLin $ BinOp FAdd t1 t2 + FDiv -> do + t1 <- emitLin $ BinOp FDiv tx y + xyy <- fdiv x =<< fmul y y + t2 <- emitLin $ BinOp FMul ty xyy + emitLin $ BinOp FSub t1 t2 + FPow -> do + ym1 <- fsub y (1.0 `fLitLike` y) + xpowym1 <- emit $ BinOp FPow x ym1 + xlogx <- fmul x =<< emit (UnOp Log x) + t1 <- emitLin $ BinOp FMul tx y + t2 <- emitLin $ BinOp FMul ty xlogx + t12 <- emitLin $ BinOp FAdd t1 t2 + emitLin $ BinOp FMul xpowym1 t12 + FCmp _ -> zero + BAnd -> zero + BOr -> zero + BXor -> zero + BShL -> zero + BShR -> zero -- This has the same type as `sinkM` and falls back thereto, but recomputes -- indexing a primal array in the tangent to avoid materializing intermediate @@ -566,23 +582,24 @@ linearizeBinOp op x' y' = do referToPrimal :: (Builder SimpIR m, Emits l, DExt n l) => SAtom n -> m l (SAtom l) referToPrimal x = do case x of - Var v -> lookupEnv (atomVarName $ sink v) >>= \case + Stuck _ (Var v) -> lookupEnv (atomVarName $ sink v) >>= \case AtomNameBinding (LetBound (DeclBinding PlainLet (Atom atom))) -> referToPrimal atom - AtomNameBinding (LetBound (DeclBinding PlainLet (TabApp _ tab is))) -> do + AtomNameBinding (LetBound (DeclBinding PlainLet (TabApp _ tab i))) -> do tab' <- referToPrimal tab - is' <- mapM referToPrimal is - emitExpr =<< mkTabApp tab' is' + i' <- referToPrimal i + emit =<< mkTabApp tab' i' _ -> sinkM x _ -> sinkM x linearizePrimCon :: Emits o => Con SimpIR i -> LinM i o SAtom SAtom linearizePrimCon con = case con of - Lit _ -> emitZeroT - ProdCon xs -> fmapLin (ProdVal . fromComposeE) $ seqLin (fmap linearizeAtom xs) + Lit _ -> zero + ProdCon xs -> fmapLin (Con . ProdCon . fromComposeE) $ seqLin (fmap linearizeAtom xs) SumCon _ _ _ -> notImplemented - HeapVal -> emitZeroT - where emitZeroT = withZeroT $ renameM $ Con con + HeapVal -> zero + DepPair _ _ _ -> notImplemented + where zero = emitZeroT con linearizeHof :: Emits o => Hof SimpIR i -> LinM i o SAtom SAtom linearizeHof hof = case hof of @@ -590,22 +607,22 @@ linearizeHof hof = case hof of UnaryLamExpr ib body <- return lam ixTy <- renameM ixTy' (lam', Abs ib' linLam) <- withFreshBinder noHint (ixTypeType ixTy) \ib' -> do - (block', linLam) <- extendSubst (ib@>binderName ib') $ linearizeBlockDefunc body + (block', linLam) <- extendSubst (ib@>binderName ib') $ linearizeExprDefunc body return (UnaryLamExpr ib' block', Abs ib' linLam) primalsAux <- emitHof $ For d ixTy lam' case linLam of TrivialRecon linLam' -> return $ WithTangent primalsAux do Abs ib'' linLam'' <- sinkM (Abs ib' linLam') - withSubstReaderT $ buildFor noHint d (sink ixTy) \i' -> do + withSubstReaderT $ emitLin =<< mkFor noHint d (sink ixTy) \i' -> do extendSubst (ib''@>Rename (atomVarName i')) $ applyLinLam linLam'' ReconWithData reconAbs -> do primals <- buildMap primalsAux getFst return $ WithTangent primals do Abs ib'' (Abs bs linLam') <- sinkM (Abs ib' reconAbs) - withSubstReaderT $ buildFor noHint d (sink ixTy) \i' -> do + withSubstReaderT $ emitLin =<< mkFor noHint d (sink ixTy) \i' -> do extendSubst (ib''@> Rename (atomVarName i')) do - residuals' <- tabApp (sink primalsAux) (Var i') >>= getSnd >>= unpackTelescope bs + residuals' <- tabApp (sink primalsAux) (toAtom i') >>= getSnd >>= unpackTelescope bs extendSubst (bs @@> (SubstVal <$> residuals')) $ applyLinLam linLam' RunReader r lam -> do @@ -620,7 +637,7 @@ linearizeHof hof = case hof of tanEffLam <- buildEffLam noHint tt \h ref -> extendTangentArgss [h, ref] do withSubstReaderT $ applyLinLam $ sink linLam - emitHof $ RunReader rLin' tanEffLam + emitHofLin $ RunReader rLin' tanEffLam RunState Nothing sInit lam -> do WithTangent sInit' sLin <- linearizeAtom sInit (lam', recon) <- linearizeEffectFun State lam @@ -633,7 +650,7 @@ linearizeHof hof = case hof of tanEffLam <- buildEffLam noHint tt \h ref -> extendTangentArgss [h, ref] do withSubstReaderT $ applyLinLam $ sink linLam - emitHof $ RunState Nothing sLin' tanEffLam + emitHofLin $ RunState Nothing sLin' tanEffLam RunWriter Nothing bm lam -> do -- TODO: check it's actually the 0/+ monoid (or should we just build that in?) bm' <- renameM bm @@ -647,9 +664,9 @@ linearizeHof hof = case hof of tanEffLam <- buildEffLam noHint tt \h ref -> extendTangentArgss [h, ref] do withSubstReaderT $ applyLinLam $ sink linLam - emitHof $ RunWriter Nothing bm'' tanEffLam + emitHofLin $ RunWriter Nothing bm'' tanEffLam RunIO body -> do - (body', recon) <- linearizeBlockDefunc body + (body', recon) <- linearizeExprDefunc body primalAux <- emitHof $ RunIO body' (primal, linLam) <- reconstruct primalAux recon return $ WithTangent primal do @@ -658,15 +675,15 @@ linearizeHof hof = case hof of linearizeEffectFun :: RWS -> SLam i -> PrimalM i o (SLam o, LinLamAbs o) linearizeEffectFun rws (BinaryLamExpr hB refB body) = do - withFreshBinder noHint (TC HeapType) \h -> do + withFreshBinder noHint (TyCon HeapType) \h -> do bTy <- extendSubst (hB@>binderName h) $ renameM $ binderType refB withFreshBinder noHint bTy \b -> do let ref = binderVar b hVar <- sinkM $ binderVar h (body', linLam) <- extendActiveSubst hB hVar $ extendActiveSubst refB ref $ -- TODO: maybe we should check whether we need to extend the active effects - extendActiveEffs (RWSEffect rws (Var hVar)) do - linearizeBlockDefunc body + extendActiveEffs (RWSEffect rws (toAtom hVar)) do + linearizeExprDefunc body -- TODO: this assumes that references aren't returned. Our type system -- ensures that such references can never be *used* once the effect runner -- returns, but technically it's legal to return them. @@ -674,21 +691,6 @@ linearizeEffectFun rws (BinaryLamExpr hB refB body) = do return (BinaryLamExpr h b body', linLam') linearizeEffectFun _ _ = error "expect effect function to be a binary lambda" -withT :: PrimalM i o (e1 o) - -> (forall o'. (Emits o', DExt o o') => TangentM o' (e2 o')) - -> PrimalM i o (WithTangent o e1 e2) -withT p t = do - p' <- p - return $ WithTangent p' t - -withZeroT :: PrimalM i o (Atom SimpIR o) - -> PrimalM i o (WithTangent o SAtom SAtom) -withZeroT p = do - p' <- p - return $ WithTangent p' do - pTy <- return $ getType $ sink p' - zeroAt =<< tangentType pTy - notImplemented :: HasCallStack => a notImplemented = error "Not implemented" diff --git a/src/lib/Live/Eval.hs b/src/lib/Live/Eval.hs index 779d1a5ff..97f99761c 100644 --- a/src/lib/Live/Eval.hs +++ b/src/lib/Live/Eval.hs @@ -1,458 +1,335 @@ --- Copyright 2019 Google LLC +-- Copyright 2023 Google LLC -- -- Use of this source code is governed by a BSD-style -- license that can be found in the LICENSE file or at -- https://developers.google.com/open-source/licenses/bsd -module Live.Eval (RFragment (..), SetVal(..), watchAndEvalFile) where +{-# LANGUAGE UndecidableInstances #-} -import Control.Concurrent (forkIO, killThread, readChan, threadDelay, ThreadId) -import Control.Monad.Reader +module Live.Eval ( + watchAndEvalFile, EvalServer, EvalUpdate, CellsState, CellsUpdate, fmapCellsUpdate, + NodeList (..), NodeListUpdate (..), subscribeIO, nodeListAsUpdate) where + +import Control.Concurrent +import Control.Monad import Control.Monad.State.Strict -import Data.ByteString qualified as BS +import Control.Monad.Writer.Strict +import qualified Data.Map.Strict as M +import Data.Aeson (ToJSON) +import Data.Functor ((<&>)) +import Data.Maybe (fromJust) import Data.Text (Text) -import Data.Text.Encoding qualified as T -import Data.Map.Strict qualified as M - -import Data.Aeson (ToJSON, toJSON, (.=)) -import Data.Aeson qualified as A -import Data.Text.Prettyprint.Doc -import System.Directory (getModificationTime) +import Prelude hiding (span) +import GHC.Generics -import ConcreteSyntax import Actor -import RenderHtml (ToMarkup, pprintHtml) -import TopLevel -import Types.Misc +import IncState import Types.Source -import Util (onFst, onSnd) +import TopLevel +import ConcreteSyntax +import MonadUtil -type NodeId = Int -data WithId a = WithId { getNodeId :: NodeId - , withoutId :: a } - deriving Show +-- === Top-level interface === -data RFragment = RFragment (SetVal [NodeId]) - (M.Map NodeId SourceBlock) - (M.Map NodeId Result) +type EvalServer = StateServer EvalState EvalUpdate +type EvalState = CellsState SourceBlock Outputs +type EvalUpdate = CellsUpdate SourceBlock Outputs --- Start watching and evaluating the given file. Returns a channel on --- which one can subscribe to updates to the evaluation state. --- --- The overall system looks like this: --- - `forkWatchFile` creates an actor that watches the file for --- changes and sends `FileChanged` messages to the driver. --- - `runDriver` creates the main driver actor, which manages --- the evaluation state and produces rendering fragments. --- - `logServer` creates an actor that accumulates rendering fragments --- from the driver and broadcasts them to any subscribed clients. --- --- `FileChanged` messages from the watch file actor may invalidate the --- current state. The driver delegates the actual evaluation to a --- sub-thread so it can remain responsive. --- -- `watchAndEvalFile` returns the channel by which a client may -- subscribe by sending a write-only view of its input channel. -watchAndEvalFile :: FilePath -> EvalConfig -> TopStateEx - -> IO (PChan (PChan RFragment)) +watchAndEvalFile :: FilePath -> EvalConfig -> TopStateEx -> IO EvalServer watchAndEvalFile fname opts env = do - (_, resultsChan) <- spawn logServer - let cfg = (opts, subChan Publish resultsChan) - (_, driverChan) <- spawn $ runDriver cfg env - forkWatchFile fname $ subChan FileChanged driverChan - return $ subChan Subscribe resultsChan - --- === executing blocks concurrently === - -type SourceContents = Text - -type DriverCfg = (EvalConfig, PChan RFragment) - --- The evaluation-in-progress state is --- - The (identified) current top-level environment --- - If a worker is currently running, its ThreadId and the --- SourceBlock it it working on (necessarily in the current --- top-level environment) --- - The list of blocks that remain to be evaluated (if any) after --- the current worker completes. If nonempty, there should be --- a current worker. --- This is consistent at entry and exit from handling each message, --- but may be briefly inconsistent while a message is being handled. -type WorkerSpec = Maybe (ThreadId, WithId SourceBlock) -data SourceEvalState = SourceEvalState - (WithId TopStateEx) WorkerSpec [WithId SourceBlock] - -initialEvalState :: TopStateEx -> SourceEvalState -initialEvalState env = (SourceEvalState (WithId 0 env) Nothing []) - -newtype DriverM a = DriverM - { drive :: (ReaderT DriverCfg - (ReaderT (PChan DriverEvent) - (StateT (SourceEvalState, CacheState) IO)) a) - } - deriving (Functor, Applicative, Monad, MonadIO) - -type EvalCache = M.Map (SourceBlock, WithId TopStateEx) (NodeId, WithId TopStateEx) -data CacheState = CacheState - { nextBlockId :: NodeId - , nextStateId :: NodeId - , evalCache :: EvalCache } - -emptyCache :: CacheState -emptyCache = CacheState 0 1 mempty - -class (Monad m, MonadIO m) => Driver m where - askOptions :: m EvalConfig - askResultsOutput :: m (PChan RFragment) - askSelf :: m (PChan DriverEvent) - getTopState :: m (WithId TopStateEx) - putTopState :: WithId TopStateEx -> m () - -- Resets the evaluation state to initial, from the given TopStateEx. - -- Returns the old top state and the old worker spec, for reuse - refresh :: TopStateEx -> m (WithId TopStateEx, WorkerSpec) - -- Get the work chunk we are waiting for, if any - getWorkingBlock :: m (Maybe (WithId SourceBlock)) - -- Run the action if there is no worker, otherwise don't - whenNoWorker :: m () -> m () - putWorker :: WorkerSpec -> m () - -- If a block is pending, remove it from the queue and run the - -- action on it, otherwise don't. - popPending :: (WithId SourceBlock -> m ()) -> m () - putPending :: [WithId SourceBlock] -> m () - newBlockId :: m Int - newStateId :: m Int - lookupCache :: SourceBlock - -> WithId TopStateEx - -> m (Maybe (NodeId, WithId TopStateEx)) - insertCache :: SourceBlock - -> WithId TopStateEx - -> (NodeId, WithId TopStateEx) - -> m () - --- The externally visible behavior of the main driver loop: --- - When the source file changes, send the new set of visible node IDs --- (`updateResultList`) to the `PChan RFragment` --- - When a new source block is discovered, assign an ID to it and send --- the association of that block with that ID (`makeNewBlockId`) --- - When a source block is successfully evaluated, associate the result --- with its ID and send that (inside `evalBlock`) - --- Internally, we implement this behavior with a driver thread that --- forks a worker thread. Why two threads? So the driver can notice --- if a source block in progress has disappeared from the file and --- kill the worker when that happens. - --- The worker communicates with the driver by sending a "work --- complete" message. Note that a worker due to be killed may send a --- "work complete" message before the driver actually kills it. If a --- "file changed" message arrived in the interim, the TopState the --- worker delivers remains valid to enter into the cache, but should --- not change the driver's then-current TopState. - --- For this reason, the WorkComplete message contains the ids of the --- TopStateEx and SourceBlock that the woker evaluated. - -data DriverEvent = FileChanged SourceContents - | WorkComplete (WithId TopStateEx) (WithId SourceBlock) (Result, TopStateEx) - -runDriver :: DriverCfg -> TopStateEx -> Actor DriverEvent -runDriver cfg env self = do - liftM fst - $ flip runStateT (initialEvalState env, emptyCache) - $ flip runReaderT (sendOnly self) - $ flip runReaderT cfg - $ drive $ forever $ do - msg <- liftIO $ readChan self - case msg of - (FileChanged source) -> evalSource env source - (WorkComplete block topState payload) -> processWork block topState payload - --- Start evaluation of the (updated) source file in the given (fresh) --- evaluation state. The evaluation state carried in the monad is --- still the state as of the end of the previous message. -evalSource :: Driver m => TopStateEx -> SourceContents -> m () -evalSource env source = do - -- Save the old state from the monad, because we need to kill or - -- reuse the worker from it. - (oldTopState, oldWorker) <- refresh env - let UModule _ _ blocks = parseUModule Main source - (evaluated, remaining) <- tryEvalBlocksCached blocks - (reused, remaining') <- tryReuseWorker oldTopState oldWorker remaining - remaining'' <- mapM makeNewBlockId remaining' - updateResultList $ map getNodeId $ evaluated ++ reused ++ remaining'' - putPending $ reused ++ remaining'' - maybeLaunchWorker - --- See which blocks already have completed values and reuse those. -tryEvalBlocksCached :: Driver m - => [SourceBlock] - -> m ([WithId SourceBlock], [SourceBlock]) -tryEvalBlocksCached [] = return ([], []) -tryEvalBlocksCached blocks@(block:rest) = do - env <- getTopState - res <- lookupCache block env - case res of - Nothing -> return ([], blocks) - Just (blockId, env') -> do - let block' = WithId blockId block - putTopState env' - (evaluated, remaining) <- tryEvalBlocksCached rest - return (block':evaluated, remaining) - --- See whether the formerly active worker (if any) is still doing --- something useful given the list of blocks we are waiting to finish; --- if so reuse it, and if not kill it. -tryReuseWorker :: Driver m - => WithId TopStateEx - -> WorkerSpec - -> [SourceBlock] - -> m ([WithId SourceBlock], [SourceBlock]) -tryReuseWorker _ w [] = - liftIO (forM_ w (killThread . fst)) >> return ([], []) -tryReuseWorker _ Nothing blocks = - return ([], blocks) -tryReuseWorker oldEnv w@(Just (_, oldNext)) (next:rest) = do - curEnv <- getTopState - if (curEnv == oldEnv) && (withoutId oldNext == next) then do - -- Reuse the worker - putWorker w - return ([oldNext], rest) - else - liftIO (forM_ w (killThread . fst)) >> return ([], next:rest) - -processWork :: Driver m - => WithId TopStateEx - -> WithId SourceBlock - -> (Result, TopStateEx) - -> m () -processWork oldState block answer = do - -- The computed result is true regardless of whether this is the - -- worker we are waiting for or not, and therefore safe to cache - -- outside the `when` clause. There is a narrow benefit here: if a - -- worker completes normally while we're processing a FileChanged - -- message, it can send a sound WorkComplete message before we - -- actually kill it. We record that result in case the user edits - -- back to a state where it can be shown. - newState <- recordTruth oldState block answer - curState <- getTopState - waitingFor <- getWorkingBlock - when (oldState == curState - && (fmap withoutId waitingFor == Just (withoutId block))) $ do - -- We only update our working state if this message is, in fact, - -- from the worker we are currently waiting for. - rotateWorkingState newState - --- Record what the worker computed in our cache of truths, and return --- the updated environment. This is sound regardless of whether we --- are waiting for this evaluation or not. -recordTruth :: Driver m - => WithId TopStateEx - -> WithId SourceBlock - -> (Result, TopStateEx) - -> m (WithId TopStateEx) -recordTruth oldState (WithId blockId block) (result, s) = do - resultsChan <- askResultsOutput - liftIO $ resultsChan `sendPChan` oneResult blockId result - newState <- makeNewStateId s - insertCache block oldState (blockId, newState) - return newState - --- Update our current evaluation state assuming the work we were --- waiting for was just completed with the given new evaluation --- environment. -rotateWorkingState :: Driver m => WithId TopStateEx -> m () -rotateWorkingState newState = do - putTopState newState - putWorker Nothing -- Worker finished - maybeLaunchWorker - --- === DriverM utils === - --- If we have work to do but no worker doing it, launch such a worker. -maybeLaunchWorker :: (Driver m) => m () -maybeLaunchWorker = do - whenNoWorker $ popPending \next -> do - curState <- getTopState - opts <- askOptions - self <- askSelf - tid <- liftIO $ forkWorker opts curState next self - putWorker $ Just (tid, next) - -forkWorker :: EvalConfig -> WithId TopStateEx -> WithId SourceBlock - -> PChan DriverEvent -> IO ThreadId -forkWorker opts curState block chan = forkIO $ do - result <- evalSourceBlockIO opts (withoutId curState) (withoutId block) - chan `sendPChan` (WorkComplete curState block result) - -makeNewBlockId :: Driver m => SourceBlock -> m (WithId SourceBlock) -makeNewBlockId block = do - newId <- newBlockId - resultsChan <- askResultsOutput - liftIO $ resultsChan `sendPChan` oneSourceBlock newId block - return $ WithId newId block - -makeNewStateId :: Driver m => TopStateEx -> m (WithId TopStateEx) -makeNewStateId env = do - newId <- newStateId - return $ WithId newId env - --- === utils for sending results === - -updateResultList :: Driver m => [NodeId] -> m () -updateResultList ids = do - resultChan <- askResultsOutput - liftIO $ resultChan `sendPChan` RFragment (Set ids) mempty mempty - -oneResult :: NodeId -> Result -> RFragment -oneResult k r = RFragment mempty mempty (M.singleton k r) - -oneSourceBlock :: NodeId -> SourceBlock -> RFragment -oneSourceBlock k b = RFragment mempty (M.singleton k b) mempty - --- === watching files === - --- A non-Actor source. Sends file contents to channel whenever file --- is modified. -forkWatchFile :: FilePath -> PChan Text -> IO () -forkWatchFile fname chan = onmod fname $ sendFileContents fname chan - -sendFileContents :: String -> PChan Text -> IO () -sendFileContents fname chan = do - putStrLn $ fname ++ " updated" - s <- T.decodeUtf8 <$> BS.readFile fname - sendPChan chan s - -onmod :: FilePath -> IO () -> IO () -onmod fname action = do - action - t <- getModificationTime fname - void $ forkIO $ loop t - where - loop t = do - t' <- getModificationTime fname - threadDelay 100000 - unless (t == t') action - loop t' - --- === instances === - -instance Driver DriverM where - askOptions = DriverM $ asks fst - askResultsOutput = DriverM $ asks snd - askSelf = DriverM $ lift $ ask - getTopState = DriverM $ do - (SourceEvalState s _ _) <- gets fst - return s - - putTopState s = DriverM $ modify $ onFst \(SourceEvalState _ w blocks) - -> (SourceEvalState s w blocks) - - refresh env = DriverM $ do - (SourceEvalState oldState oldWorker _) <- gets fst - modify $ onFst $ const $ initialEvalState env - return (oldState, oldWorker) - - getWorkingBlock = DriverM $ do - (SourceEvalState _ w _) <- gets fst - return $ (fmap snd) w - - whenNoWorker (DriverM action) = DriverM $ do - (SourceEvalState _ w _) <- gets fst - case w of - (Just _) -> return () - Nothing -> action - - putWorker w = DriverM $ modify $ onFst \(SourceEvalState s _ blocks) - -> (SourceEvalState s w blocks) - - popPending action = do - (SourceEvalState _ _ curPending) <- DriverM $ gets fst - case curPending of - [] -> return () - (next:rest) -> do - DriverM $ modify $ onFst \(SourceEvalState s w _) - -> (SourceEvalState s w rest) - action next + watcher <- launchFileWatcher fname + parser <- launchCellParser watcher \source -> uModuleSourceBlocks $ parseUModule Main source + launchDagEvaluator parser env (sourceBlockEvalFun opts) - putPending blocks = DriverM $ modify $ onFst \(SourceEvalState s w _) - -> (SourceEvalState s w blocks) +sourceBlockEvalFun :: EvalConfig -> Mailbox Outputs -> TopStateEx -> SourceBlock -> IO TopStateEx +sourceBlockEvalFun cfg resultChan env block = do + let cfg' = cfg { cfgLogAction = send resultChan } + evalSourceBlockIO cfg' env block - lookupCache block env = DriverM $ do - cache <- gets (evalCache . snd) - return $ M.lookup (block, env) cache +fmapCellsUpdate :: CellsUpdate i o -> (NodeId -> i -> i') -> (NodeId -> o -> o') -> CellsUpdate i' o' +fmapCellsUpdate (NodeListUpdate t m) fi fo = NodeListUpdate t m' where + m' = mapUpdateMapWithKey m + (\k (CellState i s o) -> CellState (fi k i) s (fo k o)) + (\k (CellUpdate s o) -> CellUpdate s (fo k o)) - newBlockId = DriverM $ do - newId <- gets $ nextBlockId . snd - modify $ onSnd \cache -> cache {nextBlockId = newId + 1 } - return newId +-- === DAG diff state === - newStateId = DriverM $ do - newId <- gets $ nextStateId . snd - modify $ onSnd \cache -> cache {nextStateId = newId + 1 } - return newId +-- We intend to make this an arbitrary Dag at some point but for now we just +-- assume that dependence is just given by the top-to-bottom ordering of blocks +-- within the file. - insertCache block env val = DriverM $ modify $ onSnd \cache -> - cache { evalCache = M.insert (block, env) val $ evalCache cache } - -instance Semigroup RFragment where - (RFragment x y z) <> (RFragment x' y' z') = RFragment (x<>x') (y<>y') (z<>z') - -instance Monoid RFragment where - mempty = RFragment mempty mempty mempty - -instance Eq (WithId a) where - (==) (WithId x _) (WithId y _) = x == y - -instance Ord (WithId a) where - compare (WithId x _) (WithId y _) = compare x y - -instance ToJSON a => ToJSON (SetVal a) where - toJSON (Set x) = A.object ["val" .= toJSON x] - toJSON NotSet = A.Null - -instance (ToJSON k, ToJSON v) => ToJSON (MonMap k v) where - toJSON (MonMap m) = toJSON (M.toList m) - -instance ToJSON RFragment where - toJSON (RFragment ids blocks results) = toJSON (ids, contents) - where contents = MonMap (M.map toHtmlFragment blocks) - <> MonMap (M.map toHtmlFragment results) - -type TreeAddress = [Int] -type HtmlFragment = [(TreeAddress, String)] - -toHtmlFragment :: ToMarkup a => a -> HtmlFragment -toHtmlFragment x = [([], pprintHtml x)] - -instance Pretty SourceEvalState where - pretty (SourceEvalState env worker pending) = - "In env ID" <+> pretty (getNodeId env) <> line - <> "waiting for" <+> pretty (show worker) <+> "to evaluate" <> line - <> pretty (map prettify pending) where - prettify (WithId blockId block) = (blockId, block) - -instance Pretty DriverEvent where - pretty (FileChanged contents) = "New file contents" <> line <> pretty contents - pretty (WorkComplete env (WithId blockId block) (result, _)) = - "Finished evaluating" <+> pretty (blockId, block) - <+> "in env with ID" <+> pretty (getNodeId env) - <+> "got" <+> pretty result - --- === some handy monoids === - -data SetVal a = Set a | NotSet - -instance Semigroup (SetVal a) where - x <> NotSet = x - _ <> Set x = Set x - -instance Monoid (SetVal a) where - mempty = NotSet +type NodeId = Int -newtype MonMap k v = MonMap (M.Map k v) deriving (Show, Eq) +data NodeList a = NodeList + { orderedNodes :: [NodeId] + , nodeMap :: M.Map NodeId a } + deriving (Show, Generic) + +data NodeListUpdate s d = NodeListUpdate + { orderedNodesUpdate :: TailUpdate NodeId + , nodeMapUpdate :: MapUpdate NodeId s d } + deriving (Show, Generic) + +instance IncState s d => Semigroup (NodeListUpdate s d) where + NodeListUpdate x1 y1 <> NodeListUpdate x2 y2 = NodeListUpdate (x1<>x2) (y1<>y2) + +instance IncState s d => Monoid (NodeListUpdate s d) where + mempty = NodeListUpdate mempty mempty + +instance IncState s d => IncState (NodeList s) (NodeListUpdate s d) where + applyDiff (NodeList m xs) (NodeListUpdate dm dxs) = + NodeList (applyDiff m dm) (applyDiff xs dxs) + +type Dag a = NodeList (Unchanging a) +type DagUpdate a = NodeListUpdate (Unchanging a) () + +nodeListAsUpdate :: NodeList s -> NodeListUpdate s d +nodeListAsUpdate (NodeList xs m)= NodeListUpdate (TailUpdate 0 xs) (MapUpdate $ fmap Create m) + +emptyNodeList :: NodeList a +emptyNodeList = NodeList [] mempty + +buildNodeList :: FreshNames NodeId m => [a] -> m (NodeList a) +buildNodeList vals = do + nodeList <- forM vals \val -> do + nodeId <- freshName + return (nodeId, val) + return $ NodeList (fst <$> nodeList) (M.fromList nodeList) + +commonPrefixLength :: Eq a => [a] -> [a] -> Int +commonPrefixLength (x:xs) (y:ys) | x == y = 1 + commonPrefixLength xs ys +commonPrefixLength _ _ = 0 + +nodeListVals :: NodeList a -> [a] +nodeListVals nodes = orderedNodes nodes <&> \k -> fromJust $ M.lookup k (nodeMap nodes) + +computeNodeListUpdate :: (Eq s, FreshNames NodeId m) => NodeList s -> [s] -> m (NodeListUpdate s d) +computeNodeListUpdate nodes newVals = do + let prefixLength = commonPrefixLength (nodeListVals nodes) newVals + let oldTail = drop prefixLength $ orderedNodes nodes + NodeList newTail nodesCreated <- buildNodeList $ drop prefixLength newVals + let nodeUpdates = fmap Create nodesCreated <> M.fromList (fmap (,Delete) oldTail) + return $ NodeListUpdate (TailUpdate (length oldTail) newTail) (MapUpdate nodeUpdates) + +-- === Cell parser === + +-- This coarsely parses the full file into blocks and forms a DAG (for now a +-- trivial one assuming all top-to-bottom dependencies) of the results. + +type CellParser a = StateServer (Dag a) (DagUpdate a) + +data CellParserMsg a = + Subscribe_CP (SubscribeMsg (Dag a) (DagUpdate a)) + | Update_CP (Overwrite Text) + deriving (Show) + +launchCellParser :: (Eq a, MonadIO m) => FileWatcher -> (Text -> [a]) -> m (CellParser a) +launchCellParser fileWatcher parseCells = + sliceMailbox Subscribe_CP <$> launchActor (cellParserImpl fileWatcher parseCells) + +cellParserImpl :: Eq a => FileWatcher -> (Text -> [a]) -> ActorM (CellParserMsg a) () +cellParserImpl fileWatcher parseCells = runFreshNameT do + Overwritable initContents <- subscribe Update_CP fileWatcher + initNodeList <- buildNodeList $ fmap Unchanging $ parseCells initContents + runIncServerT initNodeList $ messageLoop \case + Subscribe_CP msg -> handleSubscribeMsg msg + Update_CP NoChange -> return () + Update_CP (OverwriteWith newContents) -> do + let newCells = fmap Unchanging $ parseCells newContents + curNodeList <- getl It + update =<< computeNodeListUpdate curNodeList newCells + flushDiffs + +-- === Dag evaluator === + +-- This is where we track the state of evaluation and decide what we needs to be +-- run and what needs to be killed. + +type Evaluator i o = StateServer (CellsState i o) (CellsUpdate i o) +newtype EvaluatorM s i o a = + EvaluatorM { runEvaluatorM' :: + IncServerT (CellsState i o) (CellsUpdate i o) + (StateT (EvaluatorState s i o) + (ActorM (EvaluatorMsg s i o))) a } + deriving (Functor, Applicative, Monad, MonadIO, + Actor (EvaluatorMsg s i o)) +deriving instance Monoid o => IncServer (CellsState i o) (CellsUpdate i o) (EvaluatorM s i o) + +instance Monoid o => Semigroup (CellUpdate o) where + CellUpdate s o <> CellUpdate s' o' = CellUpdate (s<>s') (o<>o') + +instance Monoid o => Monoid (CellUpdate o) where + mempty = CellUpdate mempty mempty + +instance Monoid o => IncState (CellState i o) (CellUpdate o) where + applyDiff (CellState source status result) (CellUpdate status' result') = + CellState source (fromOverwritable (applyDiff (Overwritable status) status')) (result <> result') + +instance Monoid o => DefuncState (EvaluatorMUpdate s i o) (EvaluatorM s i o) where + update = \case + UpdateDagEU dag -> EvaluatorM $ update dag + UpdateCurJob status -> EvaluatorM $ lift $ modify \s -> s { curRunningJob = status } + UpdateEnvs envs -> EvaluatorM $ lift $ modify \s -> s { prevEnvs = envs} + AppendEnv env -> do + envs <- getl PrevEnvs + update $ UpdateEnvs $ envs ++ [env] + UpdateCellState nodeId cellUpdate -> update $ UpdateDagEU $ NodeListUpdate mempty $ + MapUpdate $ M.singleton nodeId $ Update cellUpdate + +instance Monoid o => LabelReader (EvaluatorMLabel s i o) (EvaluatorM s i o) where + getl l = case l of + NodeListEM -> EvaluatorM $ orderedNodes <$> getl It + NodeInfo nodeId -> EvaluatorM $ M.lookup nodeId <$> nodeMap <$> getl It + PrevEnvs -> EvaluatorM $ lift $ prevEnvs <$> get + CurRunningJob -> EvaluatorM $ lift $ curRunningJob <$> get + EvalFun -> EvaluatorM $ lift $ evalFun <$> get + +data EvaluatorMUpdate s i o = + UpdateDagEU (NodeListUpdate (CellState i o) (CellUpdate o)) + | UpdateCellState NodeId (CellUpdate o) + | UpdateCurJob CurJobStatus + | UpdateEnvs [s] + | AppendEnv s + +data EvaluatorMLabel s i o a where + NodeListEM :: EvaluatorMLabel s i o [NodeId] + NodeInfo :: NodeId -> EvaluatorMLabel s i o (Maybe (CellState i o)) + PrevEnvs :: EvaluatorMLabel s i o [s] + CurRunningJob :: EvaluatorMLabel s i o (CurJobStatus) + EvalFun :: EvaluatorMLabel s i o (EvalFun s i o) + +-- `s` is the persistent state (i.e. TopEnvEx the environment) +-- `i` is the type of input cell (e.g. SourceBlock) +-- `o` is the (monoidal) type of updates, e.g. `Result` +type EvalFun s i o = Mailbox o -> s -> i -> IO s +-- It's redundant to have both NodeId and TheadId but it defends against +-- possible GHC reuse of ThreadId (I don't know if that can actually happen) +type JobId = (ThreadId, NodeId) +type CurJobStatus = Maybe (JobId, CellIndex) + +data EvaluatorState s i o = EvaluatorState + { prevEnvs :: [s] + , evalFun :: EvalFun s i o + , curRunningJob :: CurJobStatus } + +data CellStatus = Waiting | Running | Complete deriving (Show, Generic) + +data CellState i o = CellState i CellStatus o deriving (Show, Generic) +data CellUpdate o = CellUpdate (Overwrite CellStatus) o deriving (Show, Generic) + +type Show3 s i o = (Show s, Show i, Show o) + +type CellsState i o = NodeList (CellState i o) +type CellsUpdate i o = NodeListUpdate (CellState i o) (CellUpdate o) + +type CellIndex = Int -- index in the list of cells, not the NodeId + +data JobUpdate o s = PartialJobUpdate o | JobComplete s deriving (Show) + +data EvaluatorMsg s i o = + SourceUpdate (DagUpdate i) + | JobUpdate JobId (JobUpdate o s) + | Subscribe_E (SubscribeMsg (CellsState i o) (CellsUpdate i o)) + deriving (Show) + +initEvaluatorState :: s -> EvalFun s i o -> EvaluatorState s i o +initEvaluatorState s evalCell = EvaluatorState [s] evalCell Nothing + +launchDagEvaluator :: (Show3 s i o, Monoid o, MonadIO m) => CellParser i -> s -> EvalFun s i o -> m (Evaluator i o) +launchDagEvaluator cellParser env evalCell = do + mailbox <- launchActor do + let s = initEvaluatorState env evalCell + void $ flip runStateT s $ runIncServerT emptyNodeList $ runEvaluatorM' $ + dagEvaluatorImpl cellParser + return $ sliceMailbox Subscribe_E mailbox + +dagEvaluatorImpl :: (Show3 s i o, Monoid o) => CellParser i -> EvaluatorM s i o () +dagEvaluatorImpl cellParser = do + initDag <- subscribe SourceUpdate cellParser + processDagUpdate (nodeListAsUpdate initDag) >> flushDiffs + launchNextJob + messageLoop \case + Subscribe_E msg -> handleSubscribeMsg msg + SourceUpdate dagUpdate -> do + processDagUpdate dagUpdate + flushDiffs + JobUpdate jobId jobUpdate -> do + processJobUpdate jobId jobUpdate + flushDiffs + +processJobUpdate :: (Show3 s i o, Monoid o) => JobId -> JobUpdate o s -> EvaluatorM s i o () +processJobUpdate jobId jobUpdate = do + getl CurRunningJob >>= \case + Just (jobId', _) -> when (jobId == jobId') do + let nodeId = snd jobId + case jobUpdate of + JobComplete newEnv -> do + update $ UpdateCellState nodeId $ CellUpdate (OverwriteWith Complete) mempty + update $ UpdateCurJob Nothing + update $ AppendEnv newEnv + launchNextJob + flushDiffs + PartialJobUpdate result -> update $ UpdateCellState nodeId $ CellUpdate NoChange result + Nothing -> return () -- this job is a zombie + +nextCellIndex :: Monoid o => EvaluatorM s i o Int +nextCellIndex = do + envs <- getl PrevEnvs + return $ length envs - 1 + +launchNextJob :: (Show3 s i o, Monoid o) => EvaluatorM s i o () +launchNextJob = do + cellIndex <- nextCellIndex + nodeList <- getl NodeListEM + when (cellIndex < length nodeList) do -- otherwise we're all done + curEnv <- (!! cellIndex) <$> getl PrevEnvs + let nodeId = nodeList !! cellIndex + launchJob cellIndex nodeId curEnv + +launchJob :: (Show3 s i o, Monoid o) => CellIndex -> NodeId -> s -> EvaluatorM s i o () +launchJob cellIndex nodeId env = do + jobAction <- getl EvalFun + CellState source _ _ <- fromJust <$> getl (NodeInfo nodeId) + mailbox <- selfMailbox id + update $ UpdateCellState nodeId $ CellUpdate (OverwriteWith Running) mempty + threadId <- liftIO $ forkIO do + threadId <- myThreadId + let jobId = (threadId, nodeId) + let resultsMailbox = sliceMailbox (JobUpdate jobId . PartialJobUpdate) mailbox + finalEnv <- jobAction resultsMailbox env source + send mailbox $ JobUpdate jobId $ JobComplete finalEnv + let jobId = (threadId, nodeId) + update $ UpdateCurJob (Just (jobId, cellIndex)) + +computeNumValidCells :: Monoid o => TailUpdate NodeId -> EvaluatorM s i o Int +computeNumValidCells tailUpdate = do + let nDropped = numDropped tailUpdate + nTotal <- length <$> getl NodeListEM + return $ nTotal - nDropped + +processDagUpdate :: (Show3 s i o, Monoid o) => DagUpdate i -> EvaluatorM s i o () +processDagUpdate (NodeListUpdate tailUpdate mapUpdate) = do + nValid <- computeNumValidCells tailUpdate + envs <- getl PrevEnvs + update $ UpdateEnvs $ take (nValid + 1) envs + update $ UpdateDagEU $ NodeListUpdate tailUpdate $ mapUpdateMapWithKey mapUpdate + (\_ (Unchanging i) -> CellState i Waiting mempty) + (\_ () -> mempty) + getl CurRunningJob >>= \case + Nothing -> launchNextJob + Just ((threadId, _), cellIndex) + | (cellIndex >= nValid) -> do + -- Current job is no longer valid. Kill it and restart. + liftIO $ killThread threadId + update $ UpdateCurJob Nothing + launchNextJob + | otherwise -> return () -- Current job is fine. Let it continue. -instance (Ord k, Semigroup v) => Semigroup (MonMap k v) where - MonMap m <> MonMap m' = MonMap $ M.unionWith (<>) m m' +-- === instances === -instance (Ord k, Semigroup v) => Monoid (MonMap k v) where - mempty = MonMap mempty +instance ToJSON CellStatus +instance (ToJSON i, ToJSON o) => ToJSON (CellState i o) +instance ToJSON o => ToJSON (CellUpdate o) +instance (ToJSON s, ToJSON d) => ToJSON (NodeListUpdate s d) diff --git a/src/lib/Live/Terminal.hs b/src/lib/Live/Terminal.hs deleted file mode 100644 index c995ea64f..000000000 --- a/src/lib/Live/Terminal.hs +++ /dev/null @@ -1,82 +0,0 @@ --- Copyright 2019 Google LLC --- --- Use of this source code is governed by a BSD-style --- license that can be found in the LICENSE file or at --- https://developers.google.com/open-source/licenses/bsd - -module Live.Terminal (runTerminal) where - -import Control.Concurrent (Chan, readChan, forkIO) -import Control.Monad.State.Strict -import Data.Foldable (fold) -import qualified Data.Map.Strict as M - -import System.Console.ANSI (clearScreen, setCursorPosition) -import System.IO (BufferMode (..), hSetBuffering, stdin) - -import Actor -import Cat -import Live.Eval -import PPrint (printLitBlock) -import TopLevel - -runTerminal :: FilePath -> EvalConfig -> TopStateEx -> IO () -runTerminal fname opts env = do - resultsChan <- watchAndEvalFile fname opts env - displayResultsTerm resultsChan - -type DisplayPos = Int -data KeyboardCommand = ScrollUp | ScrollDown | ResetDisplay - -type TermDisplayM = StateT DisplayPos (CatT RFragment IO) - -displayResultsTerm :: PChan (PChan RFragment) -> IO () -displayResultsTerm resultsSubscribe = - runActor \self -> do - resultsSubscribe `sendPChan` subChan Left (sendOnly self) - void $ forkIO $ monitorKeyboard $ subChan Right (sendOnly self) - evalCatT $ flip evalStateT 0 $ forever $ termDisplayLoop self - -termDisplayLoop :: (Chan (Either RFragment KeyboardCommand)) -> TermDisplayM () -termDisplayLoop self = do - req <- liftIO $ readChan self - case req of - Left result -> extend result - Right command -> case command of - ScrollUp -> modify (+ 4) - ScrollDown -> modify (\p -> max 0 (p - 4)) - ResetDisplay -> put 0 - results <- look - pos <- get - case renderResults results of - Nothing -> return () - Just s -> liftIO $ do - let cropped = cropTrailingLines pos s - setCursorPosition 0 0 - clearScreen -- TODO: clean line-by-line instead - putStr cropped - -cropTrailingLines :: Int -> String -> String -cropTrailingLines n s = unlines $ reverse $ drop n $ reverse $ lines s - --- TODO: show incremental results -renderResults :: RFragment -> Maybe String -renderResults (RFragment NotSet _ _) = Nothing -renderResults (RFragment (Set ids) blocks results) = - liftM fold $ forM ids $ \i -> do - b <- M.lookup i blocks - r <- M.lookup i results - return $ printLitBlock True b r - --- A non-Actor source. Sends keyboard command signals as they occur. -monitorKeyboard :: PChan KeyboardCommand -> IO () -monitorKeyboard chan = do - hSetBuffering stdin NoBuffering - forever $ do - c <- getChar - case c of - 'k' -> chan `sendPChan` ScrollUp - 'j' -> chan `sendPChan` ScrollDown - 'q' -> chan `sendPChan` ResetDisplay - _ -> return () - diff --git a/src/lib/Live/Web.hs b/src/lib/Live/Web.hs index ad7715599..0f5739a8e 100644 --- a/src/lib/Live/Web.hs +++ b/src/lib/Live/Web.hs @@ -14,22 +14,26 @@ import Network.Wai (Application, StreamingBody, pathInfo, import Network.Wai.Handler.Warp (run) import Network.HTTP.Types (status200, status404) import Data.Aeson (ToJSON, encode) -import Data.Binary.Builder (fromByteString, Builder) +import Data.Binary.Builder (fromByteString) import Data.ByteString.Lazy (toStrict) +import qualified Data.ByteString as BS -import Paths_dex (getDataFileName) +-- import Paths_dex (getDataFileName) -import Actor import Live.Eval +import RenderHtml +import IncState +import Actor import TopLevel +import Types.Source runWeb :: FilePath -> EvalConfig -> TopStateEx -> IO () runWeb fname opts env = do - resultsChan <- watchAndEvalFile fname opts env + resultsChan <- watchAndEvalFile fname opts env >>= renderResults putStrLn "Streaming output to http://localhost:8000/" run 8000 $ serveResults resultsChan -serveResults :: ToJSON a => PChan (PChan a) -> Application +serveResults :: RenderedResultsServer -> Application serveResults resultsSubscribe request respond = do print (pathInfo request) case pathInfo request of @@ -44,16 +48,33 @@ serveResults resultsSubscribe request respond = do [("Content-Type", "text/plain")] "404 - Not Found" where respondWith dataFname ctype = do - fname <- getDataFileName dataFname + fname <- return dataFname -- lets us skip rebuilding during development + -- fname <- getDataFileName dataFname respond $ responseFile status200 [("Content-Type", ctype)] fname Nothing -resultStream :: ToJSON a => PChan (PChan a) -> StreamingBody -resultStream resultsSubscribe write flush = runActor \self -> do - write (makeBuilder ("start"::String)) >> flush - resultsSubscribe `sendPChan` (sendOnly self) - forever $ do msg <- readChan self - write (makeBuilder msg) >> flush +type RenderedResultsServer = StateServer (MonoidState RenderedResults) RenderedResults +type RenderedResults = CellsUpdate RenderedSourceBlock RenderedOutputs + +resultStream :: RenderedResultsServer -> StreamingBody +resultStream resultsServer write flush = do + sendUpdate ("start"::String) + (MonoidState initResult, resultsChan) <- subscribeIO resultsServer + sendUpdate initResult + forever $ readChan resultsChan >>= sendUpdate + where + sendUpdate :: ToJSON a => a -> IO () + sendUpdate x = write (fromByteString $ encodePacket x) >> flush -makeBuilder :: ToJSON a => a -> Builder -makeBuilder = fromByteString . toStrict . wrap . encode +encodePacket :: ToJSON a => a -> BS.ByteString +encodePacket = toStrict . wrap . encode where wrap s = "data:" <> s <> "\n\n" + +renderResults :: EvalServer -> IO RenderedResultsServer +renderResults evalServer = launchIncFunctionEvaluator evalServer + (\x -> (MonoidState $ renderEvalUpdate $ nodeListAsUpdate x, ())) + (\_ () dx -> (renderEvalUpdate dx, ())) + +renderEvalUpdate :: CellsUpdate SourceBlock Outputs -> CellsUpdate RenderedSourceBlock RenderedOutputs +renderEvalUpdate cellsUpdate = fmapCellsUpdate cellsUpdate + (\k b -> renderSourceBlock k b) + (\_ r -> renderOutputs r) diff --git a/src/lib/Logging.hs b/src/lib/Logging.hs deleted file mode 100644 index 1c3f0eef1..000000000 --- a/src/lib/Logging.hs +++ /dev/null @@ -1,84 +0,0 @@ --- Copyright 2019 Google LLC --- --- Use of this source code is governed by a BSD-style --- license that can be found in the LICENSE file or at --- https://developers.google.com/open-source/licenses/bsd - -{-# LANGUAGE UndecidableInstances #-} - -module Logging (Logger, LoggerT (..), MonadLogger (..), logIO, runLoggerT, - FilteredLogger (..), logFiltered, logSkippingFilter, - MonadLogger1, MonadLogger2, - runLogger, execLogger, logThis, readLog, ) where - -import Control.Monad -import Control.Monad.Reader -import Data.Text.Prettyprint.Doc -import Control.Concurrent.MVar -import Prelude hiding (log) -import System.IO - -import Err -import Name - -data Logger l = Logger (MVar l) (Maybe Handle) - -data FilteredLogger k l = FilteredLogger (k -> Bool) (Logger l) - -runLogger :: (Monoid l, MonadIO m) => Maybe Handle -> (Logger l -> m a) -> m (a, l) -runLogger logFile m = do - log <- liftIO $ newMVar mempty - ans <- m $ Logger log logFile - logged <- liftIO $ readMVar log - return (ans, logged) - -execLogger :: (Monoid l, MonadIO m) => Maybe Handle -> (Logger l -> m a) -> m a -execLogger logFile m = fst <$> runLogger logFile m - -logThis :: (Pretty l, Monoid l, MonadIO m) => Logger l -> l -> m () -logThis (Logger log maybeLogHandle) x = liftIO $ do - forM_ maybeLogHandle \h -> do - hPutStrLn h $ pprint x - hFlush h - modifyMVar_ log \cur -> return (cur <> x) - -logFiltered :: (Monoid l, MonadIO m, Pretty l) => FilteredLogger k l -> k -> m l -> m () -logFiltered (FilteredLogger shouldLog logger) k m = - when (shouldLog k) $ m >>= logThis logger - -logSkippingFilter :: (Monoid l, MonadIO m, Pretty l) => FilteredLogger k l -> l -> m () -logSkippingFilter (FilteredLogger _ logger) = logThis logger - -readLog :: MonadIO m => Logger l -> m l -readLog (Logger log _) = liftIO $ readMVar log - --- === monadic interface === - -newtype LoggerT l m a = LoggerT { runLoggerT' :: ReaderT (Logger l) m a } - deriving (Functor, Applicative, Monad, MonadTrans, - MonadIO, MonadFail, Fallible, Catchable) - -class (Pretty l, Monoid l, Monad m) => MonadLogger l m | m -> l where - getLogger :: m (Logger l) - withLogger :: Logger l -> m a -> m a - -instance (MonadIO m, Pretty l, Monoid l) => MonadLogger l (LoggerT l m) where - getLogger = LoggerT ask - withLogger l m = LoggerT $ local (const l) $ runLoggerT' m - -type MonadLogger1 l (m :: MonadKind1) = forall (n::S) . MonadLogger l (m n) -type MonadLogger2 l (m :: MonadKind2) = forall (n1::S) (n2::S) . MonadLogger l (m n1 n2) - -logIO :: MonadIO m => MonadLogger l m => l -> m () -logIO val = do - logger <- getLogger - liftIO $ logThis logger val - -runLoggerT :: Monoid l => Logger l -> LoggerT l m a -> m a -runLoggerT l (LoggerT m) = runReaderT m l - --- === more instances === - -instance MonadLogger l m => MonadLogger l (ReaderT r m) where - getLogger = lift getLogger - withLogger l cont = ReaderT \r -> withLogger l $ runReaderT cont r diff --git a/src/lib/Lower.hs b/src/lib/Lower.hs index 5b8456ff6..db7b83fad 100644 --- a/src/lib/Lower.hs +++ b/src/lib/Lower.hs @@ -13,7 +13,6 @@ module Lower import Prelude hiding ((.)) import Data.Functor import Data.Maybe (fromJust) -import Data.List.NonEmpty qualified as NE import Control.Category import Control.Monad.Reader import Unsafe.Coerce @@ -27,6 +26,7 @@ import Name import Subst import QueryType import Types.Core +import Types.Top import Types.Primitives import Util (enumerate) @@ -59,45 +59,32 @@ import Util (enumerate) -- destination to a sub-block or sub-expression, hence "desintation -- passing style"). -type DestBlock = Abs (SBinder) SBlock +type DestBlock = Abs (SBinder) SExpr lowerFullySequential :: EnvReader m => Bool -> STopLam n -> m n (STopLam n) -lowerFullySequential wantDestStyle (TopLam False piTy (LamExpr bs body)) = liftEnvReaderM $ do - lam <- case wantDestStyle of - True -> do - refreshAbs (Abs bs body) \bs' body' -> do +lowerFullySequential wantDestStyle (TopLam False piTy (LamExpr bs body)) = liftEnvReaderM do + lam <- refreshAbs (Abs bs body) \bs' body' -> + liftAtomSubstBuilder case wantDestStyle of + True -> do xs <- bindersToAtoms bs' EffTy _ resultTy <- instantiate (sink piTy) xs - Abs b body'' <- lowerFullySequentialBlock resultTy body' - return $ LamExpr (bs' >>> UnaryNest b) body'' - False -> do - refreshAbs (Abs bs body) \bs' body' -> do - body'' <- lowerFullySequentialBlockNoDest body' - return $ LamExpr bs' body'' + let resultDestTy = RawRefTy resultTy + withFreshBinder "ans" resultDestTy \destBinder -> do + let dest = toAtom $ binderVar destBinder + LamExpr (bs' >>> UnaryNest destBinder) <$> buildBlock do + lowerExpr (Just (sink dest)) body' $> UnitVal + False -> LamExpr bs' <$> buildBlock (lowerExpr Nothing body') piTy' <- getLamExprType lam return $ TopLam wantDestStyle piTy' lam lowerFullySequential _ (TopLam True _ _) = error "already in destination style" -lowerFullySequentialBlock :: EnvReader m => SType n -> SBlock n -> m n (DestBlock n) -lowerFullySequentialBlock resultTy b = liftAtomSubstBuilder do - let resultDestTy = RawRefTy resultTy - withFreshBinder (getNameHint @String "ans") resultDestTy \destBinder -> do - Abs destBinder <$> buildBlock do - let dest = Var $ sink $ binderVar destBinder - lowerBlockWithDest dest b $> UnitVal -{-# SCC lowerFullySequentialBlock #-} - -lowerFullySequentialBlockNoDest :: EnvReader m => SBlock n -> m n (SBlock n) -lowerFullySequentialBlockNoDest b = liftAtomSubstBuilder $ buildBlock $ lowerBlock b -{-# SCC lowerFullySequentialBlockNoDest #-} - data LowerTag type LowerM = AtomSubstBuilder LowerTag SimpIR instance NonAtomRenamer (LowerM i o) i o where renameN = substM instance ExprVisitorEmits (LowerM i o) SimpIR i o where - visitExprEmits = lowerExpr + visitExprEmits = lowerExpr Nothing instance Visitor (LowerM i o) SimpIR i o where visitAtom = visitAtomDefault @@ -105,60 +92,45 @@ instance Visitor (LowerM i o) SimpIR i o where visitPi = visitPiDefault visitLam = visitLamEmits -lowerExpr :: Emits o => SExpr i -> LowerM i o (SAtom o) -lowerExpr expr = emitExpr =<< case expr of - TabCon Nothing ty els -> lowerTabCon Nothing ty els - PrimOp (Hof (TypedHof (EffTy _ resultTy) (For dir ixDict body))) -> do - resultTy' <- substM resultTy - lowerFor resultTy' Nothing dir ixDict body - -- this case is important because this pass changes effects - PrimOp (Hof (TypedHof _ hof)) -> - PrimOp . Hof <$> (visitGeneric hof >>= mkTypedHof) - Case e alts (EffTy _ ty) -> lowerCase Nothing e alts ty - _ -> visitGeneric expr - -lowerBlock :: Emits o => SBlock i -> LowerM i o (SAtom o) -lowerBlock = visitBlockEmits - -type Dest = Atom +type Dest = SAtom +type OptDest n = Maybe (Dest n) lowerFor :: Emits o - => SType o -> Maybe (Dest SimpIR o) -> ForAnn -> IxType SimpIR i -> LamExpr SimpIR i - -> LowerM i o (SExpr o) + => SType o -> OptDest o -> ForAnn -> IxType SimpIR i -> LamExpr SimpIR i + -> LowerM i o (SAtom o) lowerFor ansTy maybeDest dir ixTy (UnaryLamExpr (ib:>ty) body) = do ixTy' <- substM ixTy ty' <- substM ty case isSingletonType ansTy of True -> do body' <- buildUnaryLamExpr noHint (PairTy ty' UnitTy) \b' -> do - (i, _) <- fromPair $ Var b' - extendSubst (ib @> SubstVal i) $ lowerBlock body $> UnitVal + (i, _) <- fromPair $ toAtom b' + extendSubst (ib @> SubstVal i) $ lowerExpr Nothing body $> UnitVal void $ emitSeq dir ixTy' UnitVal body' - Atom . fromJust <$> singletonTypeVal ansTy + fromJust <$> singletonTypeVal ansTy False -> do - initDest <- ProdVal . (:[]) <$> case maybeDest of + initDest <- Con . ProdCon . (:[]) <$> case maybeDest of Just d -> return d - Nothing -> emitOp $ DAMOp $ AllocDest ansTy + Nothing -> emit $ AllocDest ansTy let destTy = getType initDest body' <- buildUnaryLamExpr noHint (PairTy ty' destTy) \b' -> do - (i, destProd) <- fromPair $ Var b' - dest <- normalizeProj (ProjectProduct 0) destProd - idest <- emitOp =<< mkIndexRef dest i - extendSubst (ib @> SubstVal i) $ lowerBlockWithDest idest body $> UnitVal - ans <- emitSeq dir ixTy' initDest body' >>= getProj 0 - return $ PrimOp $ DAMOp $ Freeze ans + (i, destProd) <- fromPair $ toAtom b' + dest <- proj 0 destProd + idest <- emit =<< mkIndexRef dest i + extendSubst (ib @> SubstVal i) $ lowerExpr (Just idest) body $> UnitVal + ans <- emitSeq dir ixTy' initDest body' >>= proj 0 + emit $ Freeze ans lowerFor _ _ _ _ _ = error "expected a unary lambda expression" -lowerTabCon :: forall i o. Emits o - => Maybe (Dest SimpIR o) -> SType i -> [SAtom i] -> LowerM i o (SExpr o) +lowerTabCon :: Emits o => OptDest o -> SType i -> [SAtom i] -> LowerM i o (SAtom o) lowerTabCon maybeDest tabTy elems = do - TabPi tabTy' <- substM tabTy + TyCon (TabPi tabTy') <- substM tabTy dest <- case maybeDest of Just d -> return d - Nothing -> emitExpr $ PrimOp $ DAMOp $ AllocDest $ TabPi tabTy' + Nothing -> emit $ AllocDest $ TyCon $ TabPi tabTy' Abs bord ufoBlock <- buildAbs noHint IdxRepTy \ord -> do - buildBlock $ unsafeFromOrdinal (sink $ tabIxType tabTy') $ Var $ sink ord + buildBlock $ unsafeFromOrdinal (sink $ tabIxType tabTy') $ toAtom $ sink ord -- This is emitting a chain of RememberDest ops to force `dest` to be used -- linearly, and to force reads of the `Freeze dest'` result not to be -- reordered in front of the writes. @@ -168,23 +140,23 @@ lowerTabCon maybeDest tabTy elems = do let go incoming_dest [] = return incoming_dest go incoming_dest ((ord, e):rest) = do i <- dropSubst $ extendSubst (bord@>SubstVal (IdxRepVal (fromIntegral ord))) $ - lowerBlock ufoBlock + lowerExpr Nothing ufoBlock carried_dest <- buildRememberDest "dest" incoming_dest \local_dest -> do - idest <- indexRef (Var local_dest) (sink i) - place (FullDest idest) =<< visitAtom e + idest <- indexRef (toAtom local_dest) (sink i) + place idest =<< visitAtom e return UnitVal go carried_dest rest dest' <- go dest (enumerate elems) - return $ PrimOp $ DAMOp $ Freeze dest' + emit $ Freeze dest' lowerCase :: Emits o - => Maybe (Dest SimpIR o) -> SAtom i -> [Alt SimpIR i] -> SType i - -> LowerM i o (SExpr o) + => OptDest o -> SAtom i -> [Alt SimpIR i] -> SType i + -> LowerM i o (SAtom o) lowerCase maybeDest scrut alts resultTy = do resultTy' <- substM resultTy dest <- case maybeDest of Just d -> return d - Nothing -> emitExpr $ PrimOp $ DAMOp $ AllocDest resultTy' + Nothing -> emit $ AllocDest resultTy' scrut' <- visitAtom scrut dest' <- buildRememberDest "case_dest" dest \local_dest -> do alts' <- forM alts \(Abs (b:>ty) body) -> do @@ -192,10 +164,10 @@ lowerCase maybeDest scrut alts resultTy = do buildAbs (getNameHint b) ty' \b' -> extendSubst (b @> Rename (atomVarName b')) $ buildBlock do - lowerBlockWithDest (Var $ sink $ local_dest) body $> UnitVal - void $ mkCase (sink scrut') UnitTy alts' >>= emitExpr + lowerExpr (Just (toAtom $ sink $ local_dest)) body $> UnitVal + void $ mkCase (sink scrut') UnitTy alts' >>= emit return UnitVal - return $ PrimOp $ DAMOp $ Freeze dest' + emit $ Freeze dest' -- Destination-passing traversals -- @@ -217,17 +189,9 @@ lowerCase maybeDest scrut alts resultTy = do -- so that it never allocates scratch space for its result, but will put it directly in -- the corresponding slice of the full 2D buffer. -type DestAssignment (i'::S) (o::S) = NameMap (AtomNameC SimpIR) (ProjDest o) i' +type DestAssignment (i'::S) (o::S) = NameMap (AtomNameC SimpIR) (Dest o) i' -data ProjDest o - = FullDest (Dest SimpIR o) - | ProjDest (NE.NonEmpty Projection) (Dest SimpIR o) -- dest corresponds to the projection applied to name - deriving (Show) - -instance SinkableE ProjDest where - sinkingProofE = todoSinkableProof - -lookupDest :: DestAssignment i' o -> SAtomName i' -> Maybe (ProjDest o) +lookupDest :: DestAssignment i' o -> SAtomName i' -> OptDest o lookupDest dests = fmap fromLiftE . flip lookupNameMapE dests -- Matches up the free variables of the atom, with the given dest. For example, if the @@ -237,124 +201,108 @@ lookupDest dests = fmap fromLiftE . flip lookupNameMapE dests -- as much as possible, but it can lead to unnecessary copies being done at run-time. -- -- XXX: When adding more cases, be careful about potentially repeated vars in the output! -decomposeDest :: Emits o => Dest SimpIR o -> SAtom i' -> LowerM i o (Maybe (DestAssignment i' o)) +decomposeDest :: Emits o => Dest o -> SExpr i' -> LowerM i o (Maybe (DestAssignment i' o)) decomposeDest dest = \case - Var v -> return $ Just $ singletonNameMapE (atomVarName v) $ LiftE $ FullDest dest - ProjectElt _ p x -> do - (ps, v) <- return $ asNaryProj p x - return $ Just $ singletonNameMapE (atomVarName v) $ LiftE $ ProjDest ps dest + Atom (Stuck _ (Var v)) -> + return $ Just $ singletonNameMapE (atomVarName v) $ LiftE dest _ -> return Nothing -lowerBlockWithDest :: Emits o => Dest SimpIR o -> SBlock i -> LowerM i o (SAtom o) -lowerBlockWithDest dest (Abs decls ans) = do - decomposeDest dest ans >>= \case - Nothing -> do - ans' <- visitDeclsEmits decls $ visitAtom ans - place (FullDest dest) ans' - return ans' - Just destMap -> do - s <- getSubst - case isDistinctNest decls of - Nothing -> error "Non-distinct decls?" - Just DistinctBetween -> do - s' <- traverseDeclNestWithDestS destMap s decls - -- But we have to emit explicit writes, for all the vars that are not defined in decls! - forM_ (toListNameMapE $ hoistNameMap decls destMap) \(n, (LiftE d)) -> do - x <- case s ! n of - Rename v -> Var <$> toAtomVar v - SubstVal a -> return a - place d x - withSubst s' $ substM ans - traverseDeclNestWithDestS :: forall i i' l o. (Emits o, DistinctBetween l i') => DestAssignment i' o -> Subst AtomSubstVal l o -> Nest (Decl SimpIR) l i' -> LowerM i o (Subst AtomSubstVal i' o) traverseDeclNestWithDestS destMap s = \case Empty -> return s - Nest (Let b (DeclBinding ann expr)) rest -> do + Nest (Let b (DeclBinding _ expr)) rest -> do DistinctBetween <- return $ withExtEvidence rest $ shortenBetween @i' b let maybeDest = lookupDest destMap $ sinkBetween $ binderName b - expr' <- withSubst s $ lowerExprWithDest maybeDest expr - v <- emitDecl (getNameHint b) ann expr' - traverseDeclNestWithDestS destMap (s <>> (b @> Rename (atomVarName v))) rest - -lowerExprWithDest :: forall i o. Emits o => Maybe (ProjDest o) -> SExpr i -> LowerM i o (SExpr o) -lowerExprWithDest dest expr = case expr of - TabCon Nothing ty els -> lowerTabCon tabDest ty els + result <- withSubst s $ lowerExpr maybeDest expr + traverseDeclNestWithDestS destMap (s <>> (b @> SubstVal result)) rest + +traverseDeclNest :: Emits o => Nest SDecl i i' -> LowerM i' o a -> LowerM i o a +traverseDeclNest decls cont = case decls of + Empty -> cont + Nest (Let b (DeclBinding _ expr)) rest -> do + x <- lowerExpr Nothing expr + extendSubst (b@>SubstVal x) $ traverseDeclNest rest cont + +lowerExpr :: forall i o. Emits o => OptDest o -> SExpr i -> LowerM i o (SAtom o) +lowerExpr dest expr = case expr of + Block _ (Abs decls result) -> case dest of + Nothing -> traverseDeclNest decls $ lowerExpr Nothing result + Just dest' -> do + decomposeDest dest' result >>= \case + Nothing -> do + traverseDeclNest decls do + lowerExpr (Just dest') result + Just destMap -> do + s <- getSubst + case isDistinctNest decls of + Nothing -> error "Non-distinct decls?" + Just DistinctBetween -> do + s' <- traverseDeclNestWithDestS destMap s decls + -- But we have to emit explicit writes, for all the vars that are not defined in decls! + forM_ (toListNameMapE $ hoistNameMap decls destMap) \(n, (LiftE d)) -> do + x <- case s ! n of + Rename v -> toAtom <$> toAtomVar v + SubstVal a -> return a + place d x + withSubst s' (substM result) >>= emit + TabCon Nothing ty els -> lowerTabCon dest ty els PrimOp (Hof (TypedHof (EffTy _ ansTy) (For dir ixDict body))) -> do ansTy' <- substM ansTy - lowerFor ansTy' tabDest dir ixDict body + lowerFor ansTy' dest dir ixDict body PrimOp (Hof (TypedHof (EffTy _ ty) (RunWriter Nothing m body))) -> do PairTy _ ansTy <- visitType ty traverseRWS ansTy body \ref' body' -> do m' <- visitGeneric m - return $ RunWriter ref' m' body' + emitHof $ RunWriter ref' m' body' PrimOp (Hof (TypedHof (EffTy _ ty) (RunState Nothing s body))) -> do PairTy _ ansTy <- visitType ty traverseRWS ansTy body \ref' body' -> do s' <- visitAtom s - return $ RunState ref' s' body' + emitHof $ RunState ref' s' body' -- this case is important because this pass changes effects PrimOp (Hof (TypedHof _ hof)) -> do - hof' <- PrimOp . Hof <$> (visitGeneric hof >>= mkTypedHof) + hof' <- emit =<< (visitGeneric hof >>= mkTypedHof) placeGeneric hof' - Case e alts (EffTy _ ty) -> case dest of - Nothing -> lowerCase Nothing e alts ty - Just (FullDest d) -> lowerCase (Just d) e alts ty - Just d -> do - ans <- lowerCase Nothing e alts ty >>= emitExprToAtom - place d ans - return $ Atom ans + Case e alts (EffTy _ ty) -> lowerCase dest e alts ty _ -> generic where - tabDest = dest <&> \case FullDest d -> d; ProjDest _ _ -> error "unexpected projection" - - generic = visitGeneric expr >>= placeGeneric + generic :: LowerM i o (SAtom o) + generic = visitGeneric expr >>= emit >>= placeGeneric + placeGeneric :: SAtom o -> LowerM i o (SAtom o) placeGeneric e = do case dest of Nothing -> return e Just d -> do - ans <- Var <$> emit e - place d ans - return $ Atom ans + place d e + return e traverseRWS :: SType o -> LamExpr SimpIR i - -> (Maybe (Dest SimpIR o) -> LamExpr SimpIR o -> LowerM i o (Hof SimpIR o)) - -> LowerM i o (SExpr o) + -> (OptDest o -> LamExpr SimpIR o -> LowerM i o (SAtom o)) + -> LowerM i o (SAtom o) traverseRWS referentTy (LamExpr (BinaryNest hb rb) body) cont = do unpackRWSDest dest >>= \case Nothing -> generic Just (bodyDest, refDest) -> do - hof <- cont refDest =<< + cont refDest =<< buildEffLam (getNameHint rb) referentTy \hb' rb' -> extendRenamer (hb@>atomVarName hb' <.> rb@>atomVarName rb') do - case bodyDest of - Nothing -> lowerBlock body - Just bd -> lowerBlockWithDest (sink bd) body - PrimOp . Hof <$> mkTypedHof hof - + lowerExpr (sink <$> bodyDest) body traverseRWS _ _ _ = error "Expected a binary lambda expression" unpackRWSDest = \case Nothing -> return Nothing - Just d -> case d of - FullDest fd -> do - bd <- getProjRef (ProjectProduct 0) fd - rd <- getProjRef (ProjectProduct 1) fd - return $ Just (Just bd, Just rd) - ProjDest (ProjectProduct 0 NE.:| []) pd -> return $ Just (Just pd, Nothing) - ProjDest (ProjectProduct 1 NE.:| []) pd -> return $ Just (Nothing, Just pd) - ProjDest _ _ -> return Nothing - -place :: Emits o => ProjDest o -> SAtom o -> LowerM i o () -place pd x = case pd of - FullDest d -> void $ emitOp $ DAMOp $ Place d x - ProjDest p d -> do - x' <- normalizeNaryProj (NE.toList p) x - void $ emitOp $ DAMOp $ Place d x' + Just d -> do + bd <- getProjRef (ProjectProduct 0) d + rd <- getProjRef (ProjectProduct 1) d + return $ Just (Just bd, Just rd) + +place :: Emits o => Dest o -> SAtom o -> LowerM i o () +place d x = void $ emit $ Place d x -- === Extensions to the name system === diff --git a/src/lib/MTL1.hs b/src/lib/MTL1.hs index 56fb1cdba..47fe8b8c1 100644 --- a/src/lib/MTL1.hs +++ b/src/lib/MTL1.hs @@ -6,27 +6,18 @@ {-# LANGUAGE UndecidableInstances #-} -module MTL1 ( - MonadTrans11 (..), HoistableState (..), - WriterT1, pattern WriterT1, runWriterT1, runWriterT1From, - StateT1, pattern StateT1, runStateT1, evalStateT1, MonadState1, - MaybeT1 (..), runMaybeT1, ReaderT1 (..), runReaderT1, - ScopedT1, pattern ScopedT1, runScopedT1, - FallibleT1, runFallibleT1, - runStreamWriterT1, StreamWriter (..), StreamWriterT1 (..), - runStreamReaderT1, StreamReader (..), StreamReaderT1 (..), - ) where +module MTL1 where import Control.Monad.Reader import Control.Monad.Writer.Class import Control.Monad.State.Strict import Control.Monad.Trans.Maybe -import qualified Control.Monad.Trans.Except as MTE import Control.Applicative import Data.Foldable (toList) import Name import Err +import Types.Top (Env) import Core (EnvReader (..), EnvExtender (..)) import Util (SnocList (..), snoc, emptySnocList) @@ -117,6 +108,14 @@ deriving instance MonadWriter s (m n) => MonadWriter s (ReaderT1 r m n) deriving instance MonadState s (m n) => MonadState s (ReaderT1 r m n) +instance (Monad1 m, Alternative1 m) => Alternative ((ReaderT1 r m) n) where + empty = lift11 empty + {-# INLINE empty #-} + ReaderT1 (ReaderT m1) <|> ReaderT1 (ReaderT m2) = + ReaderT1 $ ReaderT \r -> m1 r <|> m2 r + {-# INLINE (<|>) #-} + + instance (SinkableE r, EnvReader m) => EnvReader (ReaderT1 r m) where unsafeGetEnv = lift11 unsafeGetEnv {-# INLINE unsafeGetEnv #-} @@ -136,18 +135,11 @@ instance (SinkableE r, EnvExtender m) => EnvExtender (ReaderT1 r m) where refreshAbs ab \b e -> runReaderT1 (sink r) $ cont b e instance (Monad1 m, Fallible (m n)) => Fallible (ReaderT1 r m n) where - throwErrs = lift11 . throwErrs - addErrCtx ctx (ReaderT1 m) = ReaderT1 $ addErrCtx ctx m - {-# INLINE addErrCtx #-} + throwErr = lift11 . throwErr instance (Monad1 m, Catchable (m n)) => Catchable (ReaderT1 s m n) where catchErr (ReaderT1 m) f = ReaderT1 $ catchErr m (runReaderT1' . f) -instance (Monad1 m, CtxReader (m n)) => CtxReader (ReaderT1 s m n) where - getErrCtx = lift11 getErrCtx - {-# INLINE getErrCtx #-} - - -------------------- StateT1 -------------------- newtype StateT1 (s :: E) (m :: MonadKind1) (n :: S) (a :: *) = @@ -193,16 +185,16 @@ instance (SinkableE s, ScopeReader m) => ScopeReader (StateT1 s m) where {-# INLINE getDistinct #-} instance (Monad1 m, Fallible (m n)) => Fallible (StateT1 s m n) where - throwErrs = lift11 . throwErrs - addErrCtx ctx (WrapStateT1 m) = WrapStateT1 $ addErrCtx ctx m - {-# INLINE addErrCtx #-} + throwErr = lift11 . throwErr instance (Monad1 m, Catchable (m n)) => Catchable (StateT1 s m n) where catchErr (WrapStateT1 m) f = WrapStateT1 $ catchErr m (runStateT1' . f) -instance (Monad1 m, CtxReader (m n)) => CtxReader (StateT1 s m n) where - getErrCtx = lift11 getErrCtx - {-# INLINE getErrCtx #-} +instance (Monad1 m, Alternative1 m) => Alternative ((StateT1 s m) n) where + empty = lift11 empty + {-# INLINE empty #-} + StateT1 m1 <|> StateT1 m2 = StateT1 \s -> m1 s <|> m2 s + {-# INLINE (<|>) #-} class HoistableState (s::E) where hoistState :: BindsNames b => s n -> b n l -> s l -> s n @@ -253,7 +245,6 @@ runScopedT1 m s = fst <$> runStateT1 (runScopedT1' m) s deriving instance (Monad1 m, Fallible1 m) => Fallible (ScopedT1 s m n) deriving instance (Monad1 m, Catchable1 m) => Catchable (ScopedT1 s m n) -deriving instance (Monad1 m, CtxReader1 m) => CtxReader (ScopedT1 s m n) instance (SinkableE s, EnvExtender m) => EnvExtender (ScopedT1 s m) where refreshAbs ab cont = ScopedT1 \s -> do @@ -279,9 +270,7 @@ instance Monad (m n) => MonadFail (MaybeT1 m n) where {-# INLINE fail #-} instance Monad (m n) => Fallible (MaybeT1 m n) where - throwErrs _ = empty - addErrCtx _ cont = cont - {-# INLINE addErrCtx #-} + throwErr _ = empty instance EnvReader m => EnvReader (MaybeT1 m) where unsafeGetEnv = lift11 unsafeGetEnv @@ -297,39 +286,6 @@ instance EnvExtender m => EnvExtender (MaybeT1 m) where refreshAbs ab cont = MaybeT1 $ MaybeT $ refreshAbs ab \b e -> runMaybeT $ runMaybeT1' $ cont b e --------------------- FallibleT1 -------------------- - -newtype FallibleT1 (m::MonadKind1) (n::S) a = - FallibleT1 { fromFallibleT :: ReaderT ErrCtx (MTE.ExceptT Errs (m n)) a } - deriving (Functor, Applicative, Monad) - -runFallibleT1 :: Monad1 m => FallibleT1 m n a -> m n (Except a) -runFallibleT1 m = - MTE.runExceptT (runReaderT (fromFallibleT m) mempty) >>= \case - Right ans -> return $ Success ans - Left errs -> return $ Failure errs -{-# INLINE runFallibleT1 #-} - -instance Monad1 m => MonadFail (FallibleT1 m n) where - fail s = throw MonadFailErr s - {-# INLINE fail #-} - -instance Monad1 m => Fallible (FallibleT1 m n) where - throwErrs (Errs errs) = FallibleT1 $ ReaderT \ambientCtx -> - MTE.throwE $ Errs [Err errTy (ambientCtx <> ctx) s | Err errTy ctx s <- errs] - addErrCtx ctx (FallibleT1 m) = FallibleT1 $ local (<> ctx) m - {-# INLINE addErrCtx #-} - -instance ScopeReader m => ScopeReader (FallibleT1 m) where - unsafeGetScope = FallibleT1 $ lift $ lift unsafeGetScope - {-# INLINE unsafeGetScope #-} - getDistinct = FallibleT1 $ lift $ lift $ getDistinct - {-# INLINE getDistinct #-} - -instance EnvReader m => EnvReader (FallibleT1 m) where - unsafeGetEnv = FallibleT1 $ lift $ lift unsafeGetEnv - {-# INLINE unsafeGetEnv #-} - -------------------- StreamWriter -------------------- class Monad m => StreamWriter w m | m -> w where @@ -370,3 +326,90 @@ runStreamReaderT1 rs m = do (ans, LiftE rsRemaining) <- runStateT1 (runStreamReaderT1' m) (LiftE rs) return (ans, rsRemaining) {-# INLINE runStreamReaderT1 #-} + +-------------------- DiffState -------------------- + +class MonoidE (d::E) where + emptyE :: d n + catE :: d n -> d n -> d n + +class MonoidE d => DiffStateE (s::E) (d::E) where + updateDiffStateE :: Distinct n => Env n -> s n -> d n -> s n + +newtype DiffStateT1 (s::E) (d::E) (m::MonadKind1) (n::S) (a:: *) = + DiffStateT1' { runDiffStateT1'' :: StateT (s n, d n) (m n) a } + deriving ( Functor, Applicative, Monad, MonadFail, MonadIO + , Fallible, Catchable) + +pattern DiffStateT1 :: ((s n, d n) -> m n (a, (s n, d n))) -> DiffStateT1 s d m n a +pattern DiffStateT1 cont = DiffStateT1' (StateT cont) + +diffStateT1 + :: (EnvReader m, DiffStateE s d, MonoidE d) + => (s n -> m n (a, d n)) -> DiffStateT1 s d m n a +diffStateT1 cont = DiffStateT1 \(s, d) -> do + (ans, d') <- cont s + env <- unsafeGetEnv + Distinct <- getDistinct + return (ans, (updateDiffStateE env s d', catE d d')) +{-# INLINE diffStateT1 #-} + +runDiffStateT1 + :: (EnvReader m, DiffStateE s d, MonoidE d) + => s n -> DiffStateT1 s d m n a -> m n (a, d n) +runDiffStateT1 s (DiffStateT1' (StateT cont)) = do + (ans, (_, d)) <- cont (s, emptyE) + return (ans, d) +{-# INLINE runDiffStateT1 #-} + +class (Monad1 m, MonoidE d) + => MonadDiffState1 (m::MonadKind1) (s::E) (d::E) | m -> s, m -> d where + withDiffState :: s n -> m n a -> m n (a, d n) + updateDiffStateM :: d n -> m n () + getDiffState :: m n (s n) + +instance (EnvReader m, DiffStateE s d, MonoidE d) => MonadDiffState1 (DiffStateT1 s d m) s d where + getDiffState = DiffStateT1' $ fst <$> get + {-# INLINE getDiffState #-} + + withDiffState s cont = DiffStateT1' do + (sOld, dOld) <- get + put (s, emptyE) + ans <- runDiffStateT1'' cont + (_, dLocal) <- get + put (sOld, dOld) + return (ans, dLocal) + {-# INLINE withDiffState #-} + + updateDiffStateM d = DiffStateT1' do + (s, d') <- get + env <- lift unsafeGetEnv + Distinct <- lift getDistinct + put (updateDiffStateE env s d, catE d d') + {-# INLINE updateDiffStateM #-} + +instance MonoidE (ListE e) where + emptyE = mempty + catE = (<>) + +instance MonoidE (RListE e) where + emptyE = mempty + catE = (<>) + +instance (Monad1 m, Alternative1 m, MonoidE d) => Alternative ((DiffStateT1 s d m) n) where + empty = DiffStateT1' $ StateT \_ -> empty + {-# INLINE empty #-} + DiffStateT1' (StateT m1) <|> DiffStateT1' (StateT m2) = DiffStateT1' $ StateT \s -> + m1 s <|> m2 s + {-# INLINE (<|>) #-} + +instance (ScopeReader m, MonoidE d) => ScopeReader (DiffStateT1 s d m) where + unsafeGetScope = lift11 unsafeGetScope + getDistinct = lift11 getDistinct + +instance (EnvReader m, MonoidE d) => EnvReader (DiffStateT1 s d m) where + unsafeGetEnv = lift11 unsafeGetEnv + +instance MonadTrans11 (DiffStateT1 s d) where + lift11 m = DiffStateT1' $ lift m + {-# INLINE lift11 #-} diff --git a/src/lib/MonadUtil.hs b/src/lib/MonadUtil.hs new file mode 100644 index 000000000..17a21bd95 --- /dev/null +++ b/src/lib/MonadUtil.hs @@ -0,0 +1,96 @@ +-- Copyright 2023 Google LLC +-- +-- Use of this source code is governed by a BSD-style +-- license that can be found in the LICENSE file or at +-- https://developers.google.com/open-source/licenses/bsd + +{-# LANGUAGE UndecidableInstances #-} + +module MonadUtil ( + DefuncState (..), LabelReader (..), SingletonLabel (..), FreshNames (..), + runFreshNameT, FreshNameT (..), Logger (..), LogLevel (..), getIOLogger, + IOLoggerT (..), runIOLoggerT, LoggerT (..), runLoggerT, IOLogger (..), HasIOLogger (..)) where + +import Control.Monad.Reader +import Control.Monad.State.Strict +import Control.Monad.Writer.Strict + +import Err + +-- === Defunctionalized state === +-- Interface for state whose allowable updates are specified by a data type. +-- Useful for `IncState`, for specifying read-only env components, or +-- generally for specifying certain constraints on updates. + +class DefuncState d m | m -> d where + update :: d -> m () + +class LabelReader (l :: * -> *) m | m -> l where + getl :: l a -> m a + +data SingletonLabel a b where + It :: SingletonLabel a a + +-- === Fresh name monad === + +-- Used for ad-hoc names with no nested binders that don't need to be treated +-- carefully using the whole "foil" name system. + +class Monad m => FreshNames a m | m -> a where + freshName :: m a + +newtype FreshNameT m a = FreshNameT { runFreshNameT' :: StateT Int m a } + deriving (Functor, Applicative, Monad, MonadIO) + +instance MonadIO m => FreshNames Int (FreshNameT m) where + freshName = FreshNameT do + fresh <- get + put (fresh + 1) + return fresh + +instance FreshNames a m => FreshNames a (ReaderT r m) where + freshName = lift freshName + +runFreshNameT :: MonadIO m => FreshNameT m a -> m a +runFreshNameT cont = evalStateT (runFreshNameT' cont) 0 + +-- === Logging monad === + +data IOLogger w = IOLogger { ioLogLevel :: LogLevel + , ioLogAction :: w -> IO () } +data LogLevel = NormalLogLevel | DebugLogLevel + +class (Monoid w, Monad m) => Logger w m | m -> w where + emitLog :: w -> m () + getLogLevel :: m LogLevel + +newtype IOLoggerT w m a = IOLoggerT { runIOLoggerT' :: ReaderT (IOLogger w) m a } + deriving (Functor, Applicative, Monad, MonadIO, Fallible, MonadFail, Catchable) + +class Monad m => HasIOLogger w m | m -> w where + getIOLogAction :: Monad m => m (w -> IO ()) + +instance (Monoid w, MonadIO m) => HasIOLogger w (IOLoggerT w m) where + getIOLogAction = IOLoggerT $ asks ioLogAction + +instance (Monoid w, MonadIO m) => Logger w (IOLoggerT w m) where + emitLog w = do + logger <- getIOLogAction + liftIO $ logger w + getLogLevel = IOLoggerT $ asks ioLogLevel + +getIOLogger :: (HasIOLogger w m, Logger w m) => m (IOLogger w) +getIOLogger = IOLogger <$> getLogLevel <*> getIOLogAction + +runIOLoggerT :: (Monoid w, MonadIO m) => LogLevel -> (w -> IO ()) -> IOLoggerT w m a -> m a +runIOLoggerT logLevel write cont = runReaderT (runIOLoggerT' cont) (IOLogger logLevel write) + +newtype LoggerT w m a = LoggerT { runLoggerT' :: WriterT w m a } + deriving (Functor, Applicative, Monad, MonadIO) + +instance (Monoid w, Monad m) => Logger w (LoggerT w m) where + emitLog w = LoggerT $ tell w + getLogLevel = return NormalLogLevel + +runLoggerT :: (Monoid w, Monad m) => LoggerT w m a -> m (a, w) +runLoggerT cont = runWriterT (runLoggerT' cont) diff --git a/src/lib/Name.hs b/src/lib/Name.hs index 94b14d98c..9aa9402ca 100644 --- a/src/lib/Name.hs +++ b/src/lib/Name.hs @@ -43,7 +43,8 @@ import qualified Unsafe.Coerce as TrulyUnsafe import RawName ( RawNameMap, RawName, NameHint, HasNameHint (..) , freshRawName, rawNameFromHint, rawNames, noHint) import qualified RawName as R -import Util ( zipErr, onFst, onSnd, transitiveClosure, SnocList (..) ) +import Util ( zipErr, onFst, onSnd, transitiveClosure, SnocList (..), unsnoc ) +import PPrint import Err import IRVariants @@ -228,7 +229,7 @@ class SinkableB b => RenameB (b::B) where class (SinkableV v , forall c. Color c => RenameE (v c)) => RenameV (v::V) -type HasNamesE e = (RenameE e, HoistableE e) +type HasNamesE e = (RenameE e, SinkableE e, HoistableE e) type HasNamesB = RenameB instance RenameV Name @@ -247,6 +248,17 @@ instance Color c => RenameB (NameBinder c) where _ -> sink env <>> b @> (fromName $ binderName b') cont (scope', env') b' +-- === E-kinded functor === + +class FunctorE (f::E -> E) where + fmapE :: (forall l. e l -> e' l) -> f e n -> f e' n + +instance FunctorE ListE where + fmapE f (ListE xs) = ListE (fmap f xs) + +instance FunctorE (Abs b) where + fmapE f (Abs b e) = Abs b (f e) + -- === monadic type classes for reading and extending envs and scopes === data WithScope (e::E) (n::S) where @@ -262,6 +274,10 @@ class Monad1 m => ScopeReader (m::MonadKind1) where unsafeGetScope :: m n (Scope n) getDistinct :: m n (DistinctEvidence n) +withDistinct :: ScopeReader m => (Distinct n => m n a) -> m n a +withDistinct cont = getDistinct >>= \Distinct -> cont +{-# INLINE withDistinct #-} + class ScopeReader m => ScopeExtender (m::MonadKind1) where -- We normally use the EnvReader version, `refreshAbs`, but sometime we're -- working with raw binders that don't have env information associated with @@ -430,6 +446,9 @@ type OrdE e = (forall (n::S) . Ord (e n )) :: Constraint type OrdV v = (forall (c::C) (n::S). Ord (v c n)) :: Constraint type OrdB b = (forall (n::S) (l::S). Ord (b n l)) :: Constraint +type PrettyPrecE e = (forall (n::S) . PrettyPrec (e n )) :: Constraint +type PrettyPrecB b = (forall (n::S) (l::S). PrettyPrec (b n l)) :: Constraint + type HashableE (e::E) = forall n. Hashable (e n) data UnitE (n::S) = UnitE @@ -470,6 +489,9 @@ forgetEitherE (RightE x) = x newtype ListE (e::E) (n::S) = ListE { fromListE :: [e n] } deriving (Show, Eq, Generic) +newtype RListE (e::E) (n::S) = RListE { fromRListE :: (SnocList (e n)) } + deriving (Show, Eq, Generic) + newtype MapE (k::E) (v::E) (n::S) = MapE { fromMapE :: M.Map (k n) (v n) } deriving (Semigroup, Monoid) @@ -525,6 +547,9 @@ data WithAttrB (a:: *) (b::B) (n::S) (l::S) = WithAttrB {getAttr :: a , withoutAttr :: b n l } deriving (Show, Generic) +pattern ZipB :: [a] -> Nest b n l -> Nest (WithAttrB a b) n l +pattern ZipB attrs bs <- (unzipAttrs -> (attrs, bs)) + unzipAttrs :: Nest (WithAttrB a b) n l -> ([a], Nest b n l) unzipAttrs Empty = ([], Empty) unzipAttrs (Nest (WithAttrB a b) rest) = (a:as, Nest b bs) @@ -860,12 +885,6 @@ type MonadIO2 (m :: MonadKind2) = forall (n::S) (l::S) . MonadIO (m n l) type Catchable1 (m :: MonadKind1) = forall (n::S) . Catchable (m n ) type Catchable2 (m :: MonadKind2) = forall (n::S) (l::S) . Catchable (m n l) -type Searcher1 (m :: MonadKind1) = forall (n::S) . Searcher (m n ) -type Searcher2 (m :: MonadKind2) = forall (n::S) (l::S) . Searcher (m n l) - -type CtxReader1 (m :: MonadKind1) = forall (n::S) . CtxReader (m n ) -type CtxReader2 (m :: MonadKind2) = forall (n::S) (l::S) . CtxReader (m n l) - type MonadFail1 (m :: MonadKind1) = forall (n::S) . MonadFail (m n ) type MonadFail2 (m :: MonadKind2) = forall (n::S) (l::S) . MonadFail (m n l) @@ -1316,12 +1335,6 @@ instance (Monad1 m, Alternative (m n)) => Alternative (OutReaderT e m n) where f1 env <|> f2 env {-# INLINE (<|>) #-} -instance Searcher1 m => Searcher (OutReaderT e m n) where - OutReaderT (ReaderT f1) OutReaderT (ReaderT f2) = - OutReaderT $ ReaderT \env -> - f1 env f2 env - {-# INLINE () #-} - instance MonadWriter w (m n) => MonadWriter w (OutReaderT e m n) where tell w = OutReaderT $ lift $ tell w {-# INLINE tell #-} @@ -1549,14 +1562,7 @@ instance (ExtOutMap bindings decls, BindsNames decls, SinkableB decls, Monad m, instance (ExtOutMap bindings decls, BindsNames decls, SinkableB decls, Monad m, Fallible m) => Fallible (InplaceT bindings decls m n) where - throwErrs errs = UnsafeMakeInplaceT \_ _ -> throwErrs errs - addErrCtx ctx cont = UnsafeMakeInplaceT \env decls -> - addErrCtx ctx $ unsafeRunInplaceT cont env decls - {-# INLINE addErrCtx #-} - -instance (ExtOutMap bindings decls, BindsNames decls, SinkableB decls, Monad m, CtxReader m) - => CtxReader (InplaceT bindings decls m n) where - getErrCtx = lift1 getErrCtx + throwErr errs = UnsafeMakeInplaceT \_ _ -> throwErr errs instance ( ExtOutMap bindings decls, BindsNames decls, SinkableB decls, Monad m , Alternative m) @@ -1567,13 +1573,6 @@ instance ( ExtOutMap bindings decls, BindsNames decls, SinkableB decls, Monad m f1 env decls <|> f2 env decls {-# INLINE (<|>) #-} -instance ( ExtOutMap bindings decls, BindsNames decls, SinkableB decls, - Monad m, Alternative m, Searcher m) - => Searcher (InplaceT bindings decls m n) where - UnsafeMakeInplaceT f1 UnsafeMakeInplaceT f2 = UnsafeMakeInplaceT \env decls -> - f1 env decls f2 env decls - {-# INLINE () #-} - instance ( ExtOutMap bindings decls, BindsNames decls, SinkableB decls, Catchable m) => Catchable (InplaceT bindings decls m n) where @@ -1632,7 +1631,7 @@ newtype DoubleInplaceT (bindings::E) (d1::B) (d2::B) (m::MonadKind) (n::S) (a :: { unsafeRunDoubleInplaceT :: StateT (Scope UnsafeS, d1 UnsafeS UnsafeS) (InplaceT bindings d2 m n) a } deriving ( Functor, Applicative, Monad, MonadFail, Fallible - , CtxReader, MonadWriter w, MonadReader r, MonadIO, Catchable) + , MonadWriter w, MonadReader r, MonadIO, Catchable) liftDoubleInplaceT :: (Monad m, ExtOutMap bindings d2, OutFrag d2) @@ -2004,23 +2003,40 @@ instance (SinkableE k, SinkableE v, OrdE k) => SinkableE (MapE k v) where itemsE = ListE $ toPairE <$> M.toList m newItems = fromPairE <$> (fromListE $ sinkingProofE fresh itemsE) -instance SinkableE e => SinkableE (ListE e) where - sinkingProofE fresh (ListE xs) = ListE $ map (sinkingProofE fresh) xs - instance SinkableE e => SinkableE (NonEmptyListE e) where sinkingProofE fresh (NonEmptyListE xs) = NonEmptyListE $ fmap (sinkingProofE fresh) xs +instance SinkableE e => SinkableE (ListE e) where + sinkingProofE fresh (ListE xs) = ListE $ map (sinkingProofE fresh) xs + instance AlphaEqE e => AlphaEqE (ListE e) where alphaEqE (ListE xs) (ListE ys) | length xs == length ys = mapM_ (uncurry alphaEqE) (zip xs ys) | otherwise = zipErr instance Monoid (ListE e n) where - mempty = ListE [] + mempty = ListE mempty instance Semigroup (ListE e n) where ListE xs <> ListE ys = ListE $ xs <> ys +instance SinkableE e => SinkableE (RListE e) where + sinkingProofE fresh (RListE xs) = RListE $ fmap (sinkingProofE fresh) xs + +instance RenameE e => RenameE (RListE e) where + renameE env (RListE xs) = RListE $ fmap (renameE env) xs + +instance AlphaEqE e => AlphaEqE (RListE e) where + alphaEqE (RListE xs) (RListE ys) + | length xs == length ys = mapM_ (uncurry alphaEqE) (zip (fromReversedList xs) (fromReversedList ys)) + | otherwise = zipErr + +instance Monoid (RListE e n) where + mempty = RListE mempty + +instance Semigroup (RListE e n) where + RListE xs <> RListE ys = RListE $ xs <> ys + instance (EqE k, HashableE k) => GenericE (HashMapE k v) where type RepE (HashMapE k v) = ListE (PairE k v) fromE (HashMapE m) = ListE $ map (uncurry PairE) $ HM.toList m @@ -2149,6 +2165,11 @@ instance (PrettyE e1, PrettyE e2) => Pretty (EitherE e1 e2 n) where instance PrettyE e => Pretty (ListE e n) where pretty (ListE e) = pretty e +instance PrettyE e => Pretty (RListE e n) where + pretty (RListE e) = pretty $ unsnoc e + +deriving instance (forall c n. Pretty (v c n)) => Pretty (RecSubst v o) + instance ( Generic (b UnsafeS UnsafeS) , Generic (body UnsafeS) ) => Generic (Abs b body n) where @@ -2731,19 +2752,21 @@ canonicalizeForPrinting e cont = do ClosedWithScope scope e' -> cont $ renameE (scope, newSubst id) e' -liftHoistExcept :: Fallible m => HoistExcept a -> m a -liftHoistExcept (HoistSuccess x) = return x -liftHoistExcept (HoistFailure vs) = throw EscapedNameErr (pprint vs) +pprintCanonicalized :: (HoistableE e, RenameE e, PrettyE e) => e n -> String +pprintCanonicalized e = canonicalizeForPrinting e \e' -> pprint e' -liftHoistExcept' :: Fallible m => String -> HoistExcept a -> m a -liftHoistExcept' _ (HoistSuccess x) = return x -liftHoistExcept' msg (HoistFailure vs) = - throw EscapedNameErr $ (pprint vs) ++ "\n" ++ msg +liftHoistExcept :: Fallible m => SrcId -> HoistExcept a -> m a +liftHoistExcept _ (HoistSuccess x) = return x +liftHoistExcept sid (HoistFailure vs) = throw sid $ EscapedNameErr $ map pprint vs ignoreHoistFailure :: HasCallStack => HoistExcept a -> a ignoreHoistFailure (HoistSuccess x) = x ignoreHoistFailure (HoistFailure _) = error "hoist failure" +-- TODO: make this a no-op in the non-debug build +hardHoist :: (HasCallStack, BindsNames b, HoistableE e) => b n l -> e l -> e n +hardHoist b e = ignoreHoistFailure $ hoist b e + hoist :: (BindsNames b, HoistableE e) => b n l -> e l -> HoistExcept (e n) hoist b e = case R.disjoint fvs frag of @@ -2822,10 +2845,11 @@ exchangeBs (PairB b1 b2) = partitionBinders :: forall b b1 b2 m n l - . (SinkableB b2, HoistableB b1, BindsNames b2, Fallible m, Distinct l) => Nest b n l + . (SinkableB b2, HoistableB b1, BindsNames b2, Fallible m, Distinct l) + => SrcId -> Nest b n l -> (forall n' l'. b n' l' -> m (EitherB b1 b2 n' l')) -> m (PairB (Nest b1) (Nest b2) n l) -partitionBinders bs assignBinder = go bs where +partitionBinders sid bs assignBinder = go bs where go :: Distinct l' => Nest b n' l' -> m (PairB (Nest b1) (Nest b2) n' l') go = \case Empty -> return $ PairB Empty Empty @@ -2836,7 +2860,7 @@ partitionBinders bs assignBinder = go bs where RightB b2 -> withSubscopeDistinct bs2 case exchangeBs (PairB b2 bs1) of HoistSuccess (PairB bs1' b2') -> return $ PairB bs1' (Nest b2' bs2) - HoistFailure vs -> throw EscapedNameErr $ (pprint vs) + HoistFailure vs -> throw sid $ EscapedNameErr $ map pprint vs -- NameBinder has no free vars, so there's no risk associated with hoisting. -- The scope is completely distinct, so their exchange doesn't create any accidental @@ -2868,6 +2892,10 @@ abstractFreeVarsNoAnn vs e = Abs bs e' -> Abs bs' e' where bs' = fmapNest (\(b:>UnitE) -> b) bs +unsafeFromNest :: Nest b n l -> [b UnsafeS UnsafeS] +unsafeFromNest Empty = [] +unsafeFromNest (Nest b rest) = unsafeCoerceB b : unsafeFromNest rest + instance Color c => HoistableB (NameBinder c) where freeVarsB _ = mempty @@ -2888,6 +2916,9 @@ instance HoistableB UnitB where instance HoistableE e => HoistableE (ListE e) where freeVarsE (ListE xs) = foldMap freeVarsE xs +instance HoistableE e => HoistableE (RListE e) where + freeVarsE (RListE xs) = foldMap freeVarsE xs + -- === environments === -- The `Subst` type is purely an optimization. We could do everything using @@ -3051,6 +3082,12 @@ toSubstPairs (UnsafeMakeSubst m) = data WithRenamer e i o where WithRenamer :: SubstFrag Name i i' o -> e i' -> WithRenamer e i o +instance Category UnitB where + id = UnitB + {-# INLINE id #-} + UnitB . UnitB = UnitB + {-# INLINE (.) #-} + instance Category (Nest b) where id = Empty {-# INLINE id #-} @@ -3361,6 +3398,13 @@ hoistNameMap b = ignoreHoistFailure . hoistNameMapE b unsafeCoerceIRE :: forall (r'::IR) (r::IR) (e::IR->E) (n::S). e r n -> e r' n unsafeCoerceIRE = TrulyUnsafe.unsafeCoerce +-- === Pretty instances === + +instance PrettyPrec (Name s n) where prettyPrec = atPrec ArgPrec . pretty + +instance PrettyE ann => Pretty (BinderP c ann n l) + where pretty (b:>ty) = pretty b <> ":" <> pretty ty + -- === notes === {- diff --git a/src/lib/OccAnalysis.hs b/src/lib/OccAnalysis.hs index 2e9b3d9aa..0e75165be 100644 --- a/src/lib/OccAnalysis.hs +++ b/src/lib/OccAnalysis.hs @@ -20,6 +20,7 @@ import Occurrence hiding (Var) import Occurrence qualified as Occ import Types.Core import Types.Primitives +import Types.Top import QueryType -- === External API === @@ -28,14 +29,10 @@ import QueryType -- annotation holding a summary of how that binding is used. It also eliminates -- unused pure bindings as it goes, since it has all the needed information. -analyzeOccurrences :: EnvReader m => STopLam n -> m n (STopLam n) -analyzeOccurrences = liftLamExpr analyzeOccurrencesBlock +analyzeOccurrences :: EnvReader m => TopLam SimpIR n -> m n (TopLam SimpIR n) +analyzeOccurrences lam = liftLamExpr lam \e -> liftOCCM $ occ accessOnce e {-# INLINE analyzeOccurrences #-} -analyzeOccurrencesBlock :: EnvReader m => SBlock n -> m n (SBlock n) -analyzeOccurrencesBlock = liftOCCM . occNest accessOnce -{-# SCC analyzeOccurrencesBlock #-} - -- === Overview === -- We analyze every binding in the program for occurrence information, @@ -198,17 +195,18 @@ summaryExpr = \case summary :: SAtom n -> OCCM n (IxExpr n) summary atom = case atom of - Var v -> ixExpr $ atomVarName v - Con c -> constructor c - _ -> unknown atom + Stuck _ stuck -> case stuck of + Var v -> ixExpr $ atomVarName v + _ -> unknown atom + Con c -> case c of + -- TODO Represent the actual literal value? + Lit _ -> return $ Deterministic [] + ProdCon elts -> Product <$> mapM summary elts + SumCon _ tag payload -> Inject tag <$> summary payload + HeapVal -> invalid "HeapVal" + DepPair _ _ _ -> error "not implemented" where invalid tag = error $ "Unexpected indexing by " ++ tag - constructor = \case - -- TODO Represent the actual literal value? - Lit _ -> return $ Deterministic [] - ProdCon elts -> Product <$> mapM summary elts - SumCon _ tag payload -> Inject tag <$> summary payload - HeapVal -> invalid "HeapVal" unknown :: HoistableE e => e n -> OCCM n (IxExpr n) unknown _ = return IxAll @@ -248,16 +246,26 @@ class HasOCC (e::E) where occ :: Access n -> e n -> OCCM n (e n) instance HasOCC SAtom where + occ a = \case + Stuck t e -> Stuck <$> occ a t <*> occ a e + Con con -> liftM Con $ runOCCMVisitor a $ visitGeneric con + +instance HasOCC SStuck where occ a = \case Var (AtomVar n ty) -> do modify (<> FV (singletonNameMapE n $ AccessInfo One a)) ty' <- occTy ty return $ Var (AtomVar n ty') - ProjectElt t i x -> ProjectElt <$> occ a t <*> pure i <*> occ a x - atom -> runOCCMVisitor a $ visitAtomPartial atom + StuckProject i x -> StuckProject <$> pure i <*> occ a x + StuckTabApp array ixs -> do + (a', ixs') <- occIdx a ixs + array' <- occ a' array + return $ StuckTabApp array' ixs' + PtrVar t p -> return $ PtrVar t p + RepValAtom x -> return $ RepValAtom x instance HasOCC SType where - occ a ty = runOCCMVisitor a $ visitTypePartial ty + occ a (TyCon con) = liftM TyCon $ runOCCMVisitor a $ visitGeneric con -- TODO What, actually, is the right thing to do for type annotations? Do we -- want a rule like "we never inline into type annotations", or such? For @@ -268,7 +276,7 @@ occTy ty = occ accessOnce ty instance HasOCC SLam where occ a (LamExpr bs body) = do lam@(LamExpr bs' _) <- refreshAbs (Abs bs body) \bs' body' -> - LamExpr bs' <$> occNest (sink a) body' + LamExpr bs' <$> occ (sink a) body' countFreeVarsAsOccurrencesB bs' return lam @@ -290,11 +298,11 @@ instance HasOCC (EffTy SimpIR) where return $ EffTy effs ty' data ElimResult (n::S) where - ElimSuccess :: Abs (Nest SDecl) SAtom n -> ElimResult n - ElimFailure :: SDecl n l -> UsageInfo -> Abs (Nest SDecl) SAtom l -> ElimResult n + ElimSuccess :: Abs (Nest SDecl) SExpr n -> ElimResult n + ElimFailure :: SDecl n l -> UsageInfo -> Abs (Nest SDecl) SExpr l -> ElimResult n -occNest :: Access n -> Abs (Nest SDecl) SAtom n - -> OCCM n (Abs (Nest SDecl) SAtom n) +occNest :: Access n -> Abs (Nest SDecl) SExpr n + -> OCCM n (Abs (Nest SDecl) SExpr n) occNest a (Abs decls ans) = case decls of Empty -> Abs Empty <$> occ a ans Nest d@(Let _ binding) ds -> do @@ -354,11 +362,15 @@ instance HasOCC (DeclBinding SimpIR) where instance HasOCC SExpr where occ a = \case - TabApp t array ixs -> do + Block effTy (Abs decls ans) -> do + effTy' <- occ a effTy + Abs decls' ans' <- occNest a (Abs decls ans) + return $ Block effTy' (Abs decls' ans') + TabApp t array ix -> do t' <- occTy t - (a', ixs') <- go a ixs + (a', ix') <- occIdx a ix array' <- occ a' array - return $ TabApp t' array' ixs' + return $ TabApp t' array' ix' Case scrut alts (EffTy effs ty) -> do scrut' <- occ accessOnce scrut scrutIx <- summary scrut @@ -372,12 +384,11 @@ instance HasOCC SExpr where ref' <- occ a ref PrimOp . RefOp ref' <$> occ a op expr -> occGeneric a expr - where - go acc [] = return (acc, []) - go acc (ix:ixs) = do - (acc', ixs') <- go acc ixs - (summ, ix') <- occurrenceAndSummary ix - return (location summ acc', ix':ixs') + +occIdx :: Access n -> SAtom n -> OCCM n (Access n, SAtom n) +occIdx acc ix = do + (summ, ix') <- occurrenceAndSummary ix + return (location summ acc, ix') -- Arguments: Usage of the return value, summary of the scrutinee, the -- alternative itself. @@ -391,7 +402,7 @@ occAlt acc scrut alt = do -- case statement in that event. scrutIx <- unknown $ sink scrut extend nb scrutIx do - body' <- occNest (sink acc) body + body' <- occ (sink acc) body return $ Abs b body' ty' <- occTy ty return $ Abs (b':>ty') body' @@ -411,10 +422,10 @@ instance HasOCC (Hof SimpIR) where ixDict' <- inlinedLater ixDict occWithBinder (Abs b body) \b' body' -> do extend b' (Occ.Var $ binderName b') do - body'' <- censored (abstractFor b') (occNest accessOnce body') + body'' <- censored (abstractFor b') (occ accessOnce body') return $ For ann ixDict' (UnaryLamExpr b' body'') For _ _ _ -> error "For body should be a unary lambda expression" - While body -> While <$> censored useManyTimes (occNest accessOnce body) + While body -> While <$> censored useManyTimes (occ accessOnce body) RunReader ini bd -> do iniIx <- summary ini bd' <- oneShot a [Deterministic [], iniIx] bd @@ -451,7 +462,7 @@ instance HasOCC (Hof SimpIR) where return $ RunState Nothing ini' bd' RunState (Just _) _ _ -> error "Expecting to do occurrence analysis before destination passing." - RunIO bd -> RunIO <$> occNest a bd + RunIO bd -> RunIO <$> occ a bd RunInit _ -> -- Though this is probably not too hard to implement. Presumably -- the lambda is one-shot. @@ -459,7 +470,7 @@ instance HasOCC (Hof SimpIR) where oneShot :: Access n -> [IxExpr n] -> LamExpr SimpIR n -> OCCM n (LamExpr SimpIR n) oneShot acc [] (LamExpr Empty body) = - LamExpr Empty <$> occNest acc body + LamExpr Empty <$> occ acc body oneShot acc (ix:ixs) (LamExpr (Nest b bs) body) = do occWithBinder (Abs b (LamExpr bs body)) \b' restLam -> extend b' (sink ix) do diff --git a/src/lib/Occurrence.hs b/src/lib/Occurrence.hs index 5e024e854..ea8248de8 100644 --- a/src/lib/Occurrence.hs +++ b/src/lib/Occurrence.hs @@ -19,6 +19,7 @@ import Data.List (foldl') import Data.Store (Store (..)) import GHC.Generics (Generic (..)) +import PPrint import IRVariants import Name @@ -888,3 +889,15 @@ instance RenameE AccessInfo instance Hashable UsageInfo instance Store UsageInfo + +-- === instances === + +instance Pretty UsageInfo where + pretty (UsageInfo static (ixDepth, ct)) = + "occurs in" <+> pretty static <+> "places, read" + <+> pretty ct <+> "times, to depth" <+> pretty (show ixDepth) + +instance Pretty Count where + pretty = \case + Bounded ct -> "<=" <+> pretty ct + Unbounded -> "many" diff --git a/src/lib/Optimize.hs b/src/lib/Optimize.hs index 65d81b043..1ed73ff23 100644 --- a/src/lib/Optimize.hs +++ b/src/lib/Optimize.hs @@ -7,19 +7,15 @@ {-# LANGUAGE UndecidableInstances #-} module Optimize - ( optimize, peepholeOp, peepholeExpr, hoistLoopInvariant, dceTop, foldCast ) where + ( optimize, hoistLoopInvariant, dceTop) where import Data.Functor -import Data.Word -import Data.Bits -import Data.Bits.Floating -import Data.List import Control.Monad import Control.Monad.State.Strict -import GHC.Float import Types.Core import Types.Primitives +import Types.Top import MTL1 import Name import Subst @@ -37,183 +33,15 @@ optimize = dceTop -- Clean up user code >=> dceTop -- Clean up peephole-optimized code after unrolling >=> hoistLoopInvariant --- === Peephole optimizations === - -peepholeOp :: PrimOp SimpIR o -> EnvReaderM o (SExpr o) -peepholeOp op = case op of - MiscOp (CastOp (BaseTy (Scalar sTy)) (Con (Lit l))) -> return $ case foldCast sTy l of - Just l' -> lit l' - Nothing -> noop - -- TODO: Support more unary and binary ops. - BinOp IAdd l r -> return $ case (l, r) of - -- TODO: Shortcut when either side is zero. - (Con (Lit ll), Con (Lit rl)) -> case (ll, rl) of - (Word32Lit lv, Word32Lit lr) -> lit $ Word32Lit $ lv + lr - _ -> noop - _ -> noop - BinOp (ICmp cop) (Con (Lit ll)) (Con (Lit rl)) -> - return $ lit $ Word8Lit $ fromIntegral $ fromEnum $ case (ll, rl) of - (Int32Lit lv, Int32Lit rv) -> cmp cop lv rv - (Int64Lit lv, Int64Lit rv) -> cmp cop lv rv - (Word8Lit lv, Word8Lit rv) -> cmp cop lv rv - (Word32Lit lv, Word32Lit rv) -> cmp cop lv rv - (Word64Lit lv, Word64Lit rv) -> cmp cop lv rv - _ -> error "Ill typed ICmp?" - BinOp (FCmp cop) (Con (Lit ll)) (Con (Lit rl)) -> - return $ lit $ Word8Lit $ fromIntegral $ fromEnum $ case (ll, rl) of - (Float32Lit lv, Float32Lit rv) -> cmp cop lv rv - (Float64Lit lv, Float64Lit rv) -> cmp cop lv rv - _ -> error "Ill typed FCmp?" - BinOp BOr (Con (Lit (Word8Lit lv))) (Con (Lit (Word8Lit rv))) -> - return $ lit $ Word8Lit $ lv .|. rv - BinOp BAnd (Con (Lit (Word8Lit lv))) (Con (Lit (Word8Lit rv))) -> - return $ lit $ Word8Lit $ lv .&. rv - MiscOp (ToEnum ty (Con (Lit (Word8Lit tag)))) -> case ty of - SumTy cases -> return $ Atom $ SumVal cases (fromIntegral tag) UnitVal - _ -> error "Ill typed ToEnum?" - MiscOp (SumTag (SumVal _ tag _)) -> return $ lit $ Word8Lit $ fromIntegral tag - _ -> return noop - where - noop = PrimOp op - lit = Atom . Con . Lit - - cmp :: Ord a => CmpOp -> a -> a -> Bool - cmp = \case - Less -> (<) - Greater -> (>) - Equal -> (==) - LessEqual -> (<=) - GreaterEqual -> (>=) - -foldCast :: ScalarBaseType -> LitVal -> Maybe LitVal -foldCast sTy l = case sTy of - -- TODO: Check that the casts relating to floating-point agree with the - -- runtime behavior. The runtime is given by the `ICastOp` case in - -- ImpToLLVM.hs. We should make sure that the Haskell functions here - -- produce bitwise identical results to those instructions, by adjusting - -- either this or that as called for. - -- TODO: Also implement casts that may have unrepresentable results, i.e., - -- casting floating-point numbers to smaller floating-point numbers or to - -- fixed-point. Both of these necessarily have a much smaller dynamic range. - Int32Type -> case l of - Int32Lit _ -> Just l - Int64Lit i -> Just $ Int32Lit $ fromIntegral i - Word8Lit i -> Just $ Int32Lit $ fromIntegral i - Word32Lit i -> Just $ Int32Lit $ fromIntegral i - Word64Lit i -> Just $ Int32Lit $ fromIntegral i - Float32Lit _ -> Nothing - Float64Lit _ -> Nothing - PtrLit _ _ -> Nothing - Int64Type -> case l of - Int32Lit i -> Just $ Int64Lit $ fromIntegral i - Int64Lit _ -> Just l - Word8Lit i -> Just $ Int64Lit $ fromIntegral i - Word32Lit i -> Just $ Int64Lit $ fromIntegral i - Word64Lit i -> Just $ Int64Lit $ fromIntegral i - Float32Lit _ -> Nothing - Float64Lit _ -> Nothing - PtrLit _ _ -> Nothing - Word8Type -> case l of - Int32Lit i -> Just $ Word8Lit $ fromIntegral i - Int64Lit i -> Just $ Word8Lit $ fromIntegral i - Word8Lit _ -> Just l - Word32Lit i -> Just $ Word8Lit $ fromIntegral i - Word64Lit i -> Just $ Word8Lit $ fromIntegral i - Float32Lit _ -> Nothing - Float64Lit _ -> Nothing - PtrLit _ _ -> Nothing - Word32Type -> case l of - Int32Lit i -> Just $ Word32Lit $ fromIntegral i - Int64Lit i -> Just $ Word32Lit $ fromIntegral i - Word8Lit i -> Just $ Word32Lit $ fromIntegral i - Word32Lit _ -> Just l - Word64Lit i -> Just $ Word32Lit $ fromIntegral i - Float32Lit _ -> Nothing - Float64Lit _ -> Nothing - PtrLit _ _ -> Nothing - Word64Type -> case l of - Int32Lit i -> Just $ Word64Lit $ fromIntegral (fromIntegral i :: Word32) - Int64Lit i -> Just $ Word64Lit $ fromIntegral i - Word8Lit i -> Just $ Word64Lit $ fromIntegral i - Word32Lit i -> Just $ Word64Lit $ fromIntegral i - Word64Lit _ -> Just l - Float32Lit _ -> Nothing - Float64Lit _ -> Nothing - PtrLit _ _ -> Nothing - Float32Type -> case l of - Int32Lit i -> Just $ Float32Lit $ fixUlp i $ fromIntegral i - Int64Lit i -> Just $ Float32Lit $ fixUlp i $ fromIntegral i - Word8Lit i -> Just $ Float32Lit $ fromIntegral i - Word32Lit i -> Just $ Float32Lit $ fixUlp i $ fromIntegral i - Word64Lit i -> Just $ Float32Lit $ fixUlp i $ fromIntegral i - Float32Lit _ -> Just l - Float64Lit _ -> Nothing - PtrLit _ _ -> Nothing - Float64Type -> case l of - Int32Lit i -> Just $ Float64Lit $ fromIntegral i - Int64Lit i -> Just $ Float64Lit $ fixUlp i $ fromIntegral i - Word8Lit i -> Just $ Float64Lit $ fromIntegral i - Word32Lit i -> Just $ Float64Lit $ fromIntegral i - Word64Lit i -> Just $ Float64Lit $ fixUlp i $ fromIntegral i - Float32Lit f -> Just $ Float64Lit $ float2Double f - Float64Lit _ -> Just l - PtrLit _ _ -> Nothing - where - -- When casting an integer type to a floating-point type of lower precision - -- (e.g., int32 to float32), GHC between 7.8.3 and 9.2.2 (exclusive) rounds - -- toward zero, instead of rounding to nearest even like everybody else. - -- See https://gitlab.haskell.org/ghc/ghc/-/issues/17231. - -- - -- We patch this by manually checking the two adjacent floats to the - -- candidate answer, and using one of those if the reverse cast is closer - -- to the original input. - -- - -- This rounds to nearest. We round to nearest *even* by considering the - -- candidates in decreasing order of the number of trailing zeros they - -- exhibit when cast back to the original integer type. - fixUlp :: forall a b w. (Num a, Integral a, FiniteBits a, RealFrac b, FloatingBits b w) - => a -> b -> b - fixUlp orig candidate = res where - res = closest $ sortBy moreLowBits [candidate, candidatem1, candidatep1] - candidatem1 = nextDown candidate - candidatep1 = nextUp candidate - closest = minimumBy (\ca cb -> err ca `compare` err cb) - err cand = absdiff orig (round cand) - absdiff a b = if a >= b then a - b else b - a - moreLowBits a b = - compare (0 - countTrailingZeros (round @b @a a)) - (0 - countTrailingZeros (round @b @a b)) - -peepholeExpr :: SExpr o -> EnvReaderM o (SExpr o) -peepholeExpr expr = case expr of - PrimOp op -> peepholeOp op - TabApp _ (Var (AtomVar t _)) [IdxRepVal ord] -> - lookupAtomName t <&> \case - LetBound (DeclBinding ann (TabCon Nothing tabTy elems)) - | ann /= NoInlineLet && isFinTabTy tabTy-> - -- It is not safe to assume that this index can always be simplified! - -- For example, it might be coming from an unsafe_from_ordinal that is - -- under a case branch that would be dead for all invalid indices. - if 0 <= ord && fromIntegral ord < length elems - then Atom $ elems !! fromIntegral ord - else expr - _ -> expr - -- TODO: Apply a function to literals when it has a cheap body? - -- Think, partial evaluation of threefry. - _ -> return expr - where isFinTabTy = \case - TabPi (TabPiType (IxDictRawFin _) _ _) -> True - _ -> False - -- === Loop unrolling === unrollLoops :: EnvReader m => STopLam n -> m n (STopLam n) -unrollLoops = liftLamExpr unrollLoopsBlock +unrollLoops lam = liftLamExpr lam unrollLoopsExpr {-# SCC unrollLoops #-} -unrollLoopsBlock :: EnvReader m => SBlock n -> m n (SBlock n) -unrollLoopsBlock b = liftM fst $ - liftBuilder $ runStateT1 (runSubstReaderT idSubst (runULM $ ulBlock b)) (ULS 0) +unrollLoopsExpr :: EnvReader m => SExpr n -> m n (SExpr n) +unrollLoopsExpr b = liftM fst $ + liftBuilder $ runStateT1 (runSubstReaderT idSubst (runULM $ buildBlock $ ulExpr b)) (ULS 0) newtype ULS n = ULS Int deriving Show newtype ULM i o a = ULM { runULM :: SubstReaderT AtomSubstVal (StateT1 ULS (BuilderM SimpIR)) i o a} @@ -236,31 +64,25 @@ instance Visitor (ULM i o) SimpIR i o where instance ExprVisitorEmits (ULM i o) SimpIR i o where visitExprEmits = ulExpr -ulBlock :: SBlock i -> ULM i o (SBlock o) -ulBlock b = buildBlock $ visitBlockEmits b - -emitSubstBlock :: Emits o => SBlock i -> ULM i o (SAtom o) -emitSubstBlock (Abs decls ans) = visitDeclsEmits decls $ visitAtom ans - -- TODO: Refine the cost accounting so that operations that will become -- constant-foldable after inlining don't count towards it. ulExpr :: Emits o => SExpr i -> ULM i o (SAtom o) ulExpr expr = case expr of PrimOp (Hof (TypedHof _ (For Fwd ixTy body))) -> case ixTypeDict ixTy of - IxDictRawFin (IdxRepVal n) -> do + DictCon (IxRawFin (IdxRepVal n)) -> do (body', bodyCost) <- withLocalAccounting $ visitLamEmits body -- We add n (in the form of (... + 1) * n) for the cost of the TabCon reconstructing the result. case (bodyCost + 1) * (fromIntegral n) <= unrollBlowupThreshold of True -> case body' of UnaryLamExpr b' block' -> do vals <- dropSubst $ forM (iota n) \i -> do - extendSubst (b' @> SubstVal (IdxRepVal i)) $ emitSubstBlock block' + extendSubst (b' @> SubstVal (IdxRepVal i)) $ ulExpr block' inc $ fromIntegral n -- To account for the TabCon we emit below getLamExprType body' >>= \case PiType (UnaryNest (tb:>_)) (EffTy _ valTy) -> do - let tabTy = TabPi $ TabPiType (IxDictRawFin (IdxRepVal n)) (tb:>IdxRepTy) valTy - emitExpr $ TabCon Nothing tabTy vals + let tabTy = toType $ TabPiType (DictCon $ IxRawFin (IdxRepVal n)) (tb:>IdxRepTy) valTy + emit $ TabCon Nothing tabTy vals _ -> error "Expected `for` body to have a Pi type" _ -> error "Expected `for` body to be a lambda expression" False -> do @@ -270,11 +92,11 @@ ulExpr expr = case expr of _ -> nothingSpecial -- Avoid unrolling loops with large table literals TabCon _ _ els -> inc (length els) >> nothingSpecial + Block _ (Abs decls body) -> visitDeclsEmits decls $ ulExpr body _ -> nothingSpecial where inc i = modify \(ULS n) -> ULS (n + i) - nothingSpecial = inc 1 >> (visitGeneric expr >>= liftEnvReaderM . peepholeExpr) - >>= emitExprToAtom + nothingSpecial = inc 1 >> visitGeneric expr >>= emit unrollBlowupThreshold = 12 withLocalAccounting m = do oldCost <- get @@ -301,55 +123,56 @@ instance Visitor (LICMM i o) SimpIR i o where instance ExprVisitorEmits (LICMM i o) SimpIR i o where visitExprEmits = licmExpr -hoistLoopInvariantBlock :: EnvReader m => SBlock n -> m n (SBlock n) -hoistLoopInvariantBlock body = liftLICMM $ buildBlock $ visitBlockEmits body -{-# SCC hoistLoopInvariantBlock #-} +hoistLoopInvariantExpr :: EnvReader m => SExpr n -> m n (SExpr n) +hoistLoopInvariantExpr body = liftLICMM $ buildBlock $ visitExprEmits body +{-# SCC hoistLoopInvariantExpr #-} hoistLoopInvariant :: EnvReader m => STopLam n -> m n (STopLam n) -hoistLoopInvariant = liftLamExpr hoistLoopInvariantBlock +hoistLoopInvariant lam = liftLamExpr lam hoistLoopInvariantExpr {-# INLINE hoistLoopInvariant #-} licmExpr :: Emits o => SExpr i -> LICMM i o (SAtom o) licmExpr = \case - PrimOp (DAMOp (Seq _ dir ix (ProdVal dests) (LamExpr (UnaryNest b) body))) -> do + PrimOp (DAMOp (Seq _ dir ix (Con (ProdCon dests)) (LamExpr (UnaryNest b) body))) -> do ix' <- substM ix dests' <- mapM visitAtom dests let numCarriesOriginal = length dests' Abs hdecls destsAndBody <- visitBinders (UnaryNest b) \(UnaryNest b') -> do -- First, traverse the block, to allow any Hofs inside it to hoist their own decls. - Abs decls ans <- buildBlock $ visitBlockEmits body + Abs decls ans <- buildScoped $ visitExprEmits body -- Now, we process the decls and decide which ones to hoist. liftEnvReaderM $ runSubstReaderT idSubst $ seqLICM REmpty mempty (asNameBinder b') REmpty decls ans PairE (ListE extraDests) ab <- emitDecls $ Abs hdecls destsAndBody extraDests' <- mapM toAtomVar extraDests -- Append the destinations of hoisted Allocs as loop carried values. - let dests'' = ProdVal $ dests' ++ (Var <$> extraDests') + let dests'' = Con $ ProdCon $ dests' ++ (toAtom <$> extraDests') let carryTy = getType dests'' let lbTy = case ix' of IxType ixTy _ -> PairTy ixTy carryTy extraDestsTyped <- forM extraDests' \(AtomVar d t) -> return (d, t) Abs extraDestBs (Abs lb bodyAbs) <- return $ abstractFreeVars extraDestsTyped ab body' <- withFreshBinder noHint lbTy \lb' -> do - (oldIx, allCarries) <- fromPair $ Var $ binderVar lb' - (oldCarries, newCarries) <- splitAt numCarriesOriginal <$> getUnpacked allCarries - let oldLoopBinderVal = PairVal oldIx (ProdVal oldCarries) + (oldIx, allCarries) <- fromPairReduced $ toAtom $ binderVar lb' + (oldCarries, newCarries) <- splitAt numCarriesOriginal <$> getUnpackedReduced allCarries + let oldLoopBinderVal = Con $ ProdCon [oldIx, Con $ ProdCon oldCarries] let s = extraDestBs @@> map SubstVal newCarries <.> lb @> SubstVal oldLoopBinderVal - block <- applySubst s bodyAbs + block <- mkBlock =<< applySubst s bodyAbs return $ UnaryLamExpr lb' block emitSeq dir ix' dests'' body' PrimOp (Hof (TypedHof _ (For dir ix (LamExpr (UnaryNest b) body)))) -> do ix' <- substM ix Abs hdecls destsAndBody <- visitBinders (UnaryNest b) \(UnaryNest b') -> do - Abs decls ans <- buildBlock $ visitBlockEmits body + Abs decls ans <- buildScoped $ visitExprEmits body liftEnvReaderM $ runSubstReaderT idSubst $ seqLICM REmpty mempty (asNameBinder b') REmpty decls ans PairE (ListE []) (Abs lnb bodyAbs) <- emitDecls $ Abs hdecls destsAndBody ixTy <- substM $ binderType b body' <- withFreshBinder noHint ixTy \i -> do - block <- applyRename (lnb@>binderName i) bodyAbs + block <- mkBlock =<< applyRename (lnb@>binderName i) bodyAbs return $ UnaryLamExpr i block emitHof $ For dir ix' body' - expr -> visitGeneric expr >>= emitExpr + Block _ (Abs decls result) -> visitDeclsEmits decls $ licmExpr result + expr -> visitGeneric expr >>= emit seqLICM :: RNest SDecl n1 n2 -- hoisted decls -> [SAtomName n2] -- hoisted dests @@ -401,12 +224,12 @@ newtype DCEM n a = DCEM { runDCEM :: StateT1 FV EnvReaderM n a } , MonadState (FV n), EnvExtender) dceTop :: EnvReader m => STopLam n -> m n (STopLam n) -dceTop = liftLamExpr dceBlock +dceTop lam = liftLamExpr lam dceExpr {-# INLINE dceTop #-} -dceBlock :: EnvReader m => SBlock n -> m n (SBlock n) -dceBlock b = liftEnvReaderM $ evalStateT1 (runDCEM $ dceBlock' b) mempty -{-# SCC dceBlock #-} +dceExpr :: EnvReader m => SExpr n -> m n (SExpr n) +dceExpr b = liftEnvReaderM $ evalStateT1 (runDCEM $ dce b) mempty +{-# SCC dceExpr #-} class HasDCE (e::E) where dce :: e n -> DCEM n (e n) @@ -424,18 +247,37 @@ instance Color c => HasDCE (Name c) where dce n = modify (<> FV (freeVarsE n)) $> n instance HasDCE SAtom where - dce = \case - Var n -> modify (<> FV (freeVarsE n)) $> Var n - ProjectElt t i x -> ProjectElt <$> dce t <*> pure i <*> dce x - atom -> visitAtomPartial atom + dce atom = case atom of + Stuck _ _ -> modify (<> FV (freeVarsE atom)) $> atom + Con con -> Con <$> visitGeneric con + +instance HasDCE SType where + dce (TyCon e) = TyCon <$> visitGeneric e -instance HasDCE SType where dce = visitTypePartial instance HasDCE (PiType SimpIR) where dce (PiType bs effTy) = do dceBinders bs effTy \bs' effTy' -> PiType bs' <$> dce effTy' instance HasDCE (LamExpr SimpIR) where - dce (LamExpr bs e) = dceBinders bs e \bs' e' -> LamExpr bs' <$> dceBlock' e' + dce (LamExpr bs e) = dceBinders bs e \bs' e' -> LamExpr bs' <$> dce e' + +instance HasDCE (Expr SimpIR) where + dce = \case + Block effTy block -> do + -- The free vars accumulated in the state of DCEM should correspond to + -- the free vars of the Abs of the block answer, by the decls traversed + -- so far. dceNest takes care to uphold this invariant, but we temporarily + -- reset the state to an empty map, just so that names from the surrounding + -- block don't end up influencing elimination decisions here. Note that we + -- restore the state (and accumulate free vars of the DCE'd block into it) + -- right after dceNest. + effTy' <- dce effTy + old <- get + put mempty + block' <- dceBlock block + modify (<> old) + return $ Block effTy' block' + e -> visitGeneric e dceBinders :: (HoistableB b, BindsEnv b, RenameB b, RenameE e) @@ -448,21 +290,6 @@ dceBinders b e cont = do return ans {-# INLINE dceBinders #-} -dceBlock' :: SBlock n -> DCEM n (SBlock n) -dceBlock' (Abs decls ans) = do - -- The free vars accumulated in the state of DCEM should correspond to - -- the free vars of the Abs of the block answer, by the decls traversed - -- so far. dceNest takes care to uphold this invariant, but we temporarily - -- reset the state to an empty map, just so that names from the surrounding - -- block don't end up influencing elimination decisions here. Note that we - -- restore the state (and accumulate free vars of the DCE'd block into it) - -- right after dceNest. - old <- get - put mempty - block <- dceNest decls ans - modify (<> old) - return block - wrapWithCachedFVs :: HoistableE e => e n -> DCEM n (CachedFVs e n) wrapWithCachedFVs e = do FV fvs <- get @@ -483,11 +310,11 @@ hoistUsingCachedFVs :: (BindsNames b, HoistableE e) => hoistUsingCachedFVs b e = hoistViaCachedFVs b <$> wrapWithCachedFVs e data ElimResult n where - ElimSuccess :: Abs (Nest SDecl) SAtom n -> ElimResult n - ElimFailure :: SDecl n l -> Abs (Nest SDecl) SAtom l -> ElimResult n + ElimSuccess :: SBlock n -> ElimResult n + ElimFailure :: SDecl n l -> SBlock l -> ElimResult n -dceNest :: Nest SDecl n l -> SAtom l -> DCEM n (Abs (Nest SDecl) SAtom n) -dceNest decls ans = case decls of +dceBlock :: SBlock n -> DCEM n (SBlock n) +dceBlock (Abs decls ans) = case decls of Empty -> Abs Empty <$> dce ans Nest b@(Let _ decl) bs -> do -- Note that we only ever dce the abs below under this refreshAbs, @@ -495,7 +322,7 @@ dceNest decls ans = case decls of -- because refreshAbs of StateT1 triggers hoistState, which we -- implement by deleting the entries that can't hoist). dceAttempt <- refreshAbs (Abs b (Abs bs ans)) \b' (Abs bs' ans') -> do - below <- dceNest bs' ans' + below <- dceBlock $ Abs bs' ans' case isPure decl of False -> return $ ElimFailure b' below True -> do @@ -504,11 +331,10 @@ dceNest decls ans = case decls of HoistFailure _ -> ElimFailure b' below case dceAttempt of ElimSuccess below' -> return below' - ElimFailure (Let b' decl') (Abs bs'' ans'') -> do - decl'' <- dce decl' + ElimFailure (Let b' (DeclBinding ann expr)) (Abs bs'' ans'') -> do + expr' <- dce expr modify (<>FV (freeVarsB b')) - return $ Abs (Nest (Let b' decl'') bs'') ans'' + return $ Abs (Nest (Let b' (DeclBinding ann expr')) bs'') ans'' instance HasDCE (EffectRow SimpIR) -instance HasDCE (DeclBinding SimpIR) instance HasDCE (EffTy SimpIR) diff --git a/src/lib/PPrint.hs b/src/lib/PPrint.hs index 23bc7ea60..b16559fa5 100644 --- a/src/lib/PPrint.hs +++ b/src/lib/PPrint.hs @@ -6,43 +6,34 @@ {-# LANGUAGE IncoherentInstances #-} -- due to `ConRef` {-# LANGUAGE UndecidableInstances #-} -{-# OPTIONS_GHC -Wno-orphans #-} module PPrint ( - pprint, pprintCanonicalized, pprintList, asStr , atPrec, toJSONStr, - PrettyPrec(..), PrecedenceLevel (..), prettyBlock, printLitBlock, - printResult, prettyFromPrettyPrec) where + Pretty (..), Doc, DocPrec, (<+>), pprint, pprintList, asStr , atPrec, + pAppArg, pApp, pArg, hardline, PrettyPrec(..), PrecedenceLevel (..), + docAsStr, parensSep, prettyLines, sep, pLowest, prettyFromPrettyPrec, + indented, commaSep, spaced, spaceIfColinear, encloseSep) where -import Data.Aeson hiding (Result, Null, Value, Success) -import GHC.Exts (Constraint) -import GHC.Float import Data.Foldable (toList, fold) -import qualified Data.ByteString.Lazy.Char8 as B -import qualified Data.Map.Strict as M import Data.Text.Prettyprint.Doc.Render.Text import Data.Text.Prettyprint.Doc -import Data.Text (Text, snoc, uncons, unsnoc, unpack) -import qualified Data.Set as S -import Data.String (fromString) -import qualified System.Console.ANSI as ANSI -import System.Console.ANSI hiding (Color) +import Data.Text (unpack) import System.IO.Unsafe import qualified System.Environment as E -import Numeric -import ConcreteSyntax -import Err -import IRVariants -import Name -import Occurrence (Count (Bounded), UsageInfo (..)) -import Occurrence qualified as Occ -import Types.Core -import Types.Imp -import Types.Misc -import Types.Primitives -import Types.Source -import QueryTypePure -import Util (Tree (..)) +-- === small pretty-printing utils === + +pprint :: Pretty a => a -> String +pprint x = docAsStr $ pretty x +{-# SCC pprint #-} + +docAsStr :: Doc ann -> String +docAsStr doc = unpack $ renderStrict $ layoutPretty layout $ doc + +layout :: LayoutOptions +layout = if unbounded then LayoutOptions Unbounded else defaultLayoutOptions + where unbounded = unsafePerformIO $ (Just "1"==) <$> E.lookupEnv "DEX_PPRINT_UNBOUNDED" + +-- === DocPrec === -- A DocPrec is a slightly context-aware Doc, specifically one that -- knows the precedence level of the immediately enclosing operation, @@ -96,31 +87,12 @@ prettyFromPrettyPrec = pArg pAppArg :: (PrettyPrec a, Foldable f) => Doc ann -> f a -> Doc ann pAppArg name as = align $ name <> group (nest 2 $ foldMap (\a -> line <> pArg a) as) -fromInfix :: Text -> Maybe Text -fromInfix t = do - ('(', t') <- uncons t - (t'', ')') <- unsnoc t' - return t'' - -type PrettyPrecE e = (forall (n::S) . PrettyPrec (e n )) :: Constraint -type PrettyPrecB b = (forall (n::S) (l::S). PrettyPrec (b n l)) :: Constraint - -pprintCanonicalized :: (HoistableE e, RenameE e, PrettyE e) => e n -> String -pprintCanonicalized e = canonicalizeForPrinting e \e' -> pprint e' - pprintList :: Pretty a => [a] -> String -pprintList xs = asStr $ vsep $ punctuate "," (map p xs) - -layout :: LayoutOptions -layout = if unbounded then LayoutOptions Unbounded else defaultLayoutOptions - where unbounded = unsafePerformIO $ (Just "1"==) <$> E.lookupEnv "DEX_PPRINT_UNBOUNDED" +pprintList xs = asStr $ vsep $ punctuate "," (map pretty xs) asStr :: Doc ann -> String asStr doc = unpack $ renderStrict $ layoutPretty layout $ doc -p :: Pretty a => a -> Doc ann -p = pretty - pLowest :: PrettyPrec a => a -> Doc ann pLowest a = prettyPrec a LowestPrec @@ -130,22 +102,8 @@ pApp a = prettyPrec a AppPrec pArg :: PrettyPrec a => a -> Doc ann pArg a = prettyPrec a ArgPrec -instance IRRep r => Pretty (Block r n) where - pretty (Abs decls expr) = prettyBlock decls expr -instance IRRep r => PrettyPrec (Block r n) where - prettyPrec (Abs decls expr) = atPrec LowestPrec $ prettyBlock decls expr - -prettyBlock :: (IRRep r, PrettyPrec (e l)) => Nest (Decl r) n l -> e l -> Doc ann -prettyBlock Empty expr = group $ line <> pLowest expr -prettyBlock decls expr = prettyLines decls' <> hardline <> pLowest expr - where decls' = fromNest decls - -fromNest :: Nest b n l -> [b UnsafeS UnsafeS] -fromNest Empty = [] -fromNest (Nest b rest) = unsafeCoerceB b : fromNest rest - prettyLines :: (Foldable f, Pretty a) => f a -> Doc ann -prettyLines xs = foldMap (\d -> hardline <> p d) $ toList xs +prettyLines xs = foldMap (\d -> hardline <> pretty d) $ toList xs parensSep :: Doc ann -> [Doc ann] -> Doc ann parensSep separator items = encloseSep "(" ")" separator items @@ -156,994 +114,13 @@ spaceIfColinear = flatAlt "" space instance PrettyPrec a => PrettyPrec [a] where prettyPrec xs = atPrec ArgPrec $ hsep $ map pLowest xs -instance PrettyE ann => Pretty (BinderP c ann n l) - where pretty (b:>ty) = p b <> ":" <> p ty - -instance IRRep r => Pretty (Expr r n) where pretty = prettyFromPrettyPrec -instance IRRep r => PrettyPrec (Expr r n) where - prettyPrec (Atom x) = prettyPrec x - prettyPrec (App _ f xs) = atPrec AppPrec $ pApp f <+> spaced (toList xs) - prettyPrec (TopApp _ f xs) = atPrec AppPrec $ pApp f <+> spaced (toList xs) - prettyPrec (TabApp _ f xs) = atPrec AppPrec $ pApp f <> "." <> dotted (toList xs) - prettyPrec (Case e alts (EffTy effs _)) = prettyPrecCase "case" e alts effs - prettyPrec (TabCon _ _ es) = atPrec ArgPrec $ list $ pApp <$> es - prettyPrec (PrimOp op) = prettyPrec op - prettyPrec (ApplyMethod _ d i xs) = atPrec AppPrec $ "applyMethod" <+> p d <+> p i <+> p xs - -prettyPrecCase :: IRRep r => Doc ann -> Atom r n -> [Alt r n] -> EffectRow r n -> DocPrec ann -prettyPrecCase name e alts effs = atPrec LowestPrec $ - name <+> pApp e <+> "of" <> - nest 2 (foldMap (\alt -> hardline <> prettyAlt alt) alts - <> effectLine effs) - where - effectLine :: IRRep r => EffectRow r n -> Doc ann - effectLine Pure = "" - effectLine row = hardline <> "case annotated with effects" <+> p row - -prettyAlt :: IRRep r => Alt r n -> Doc ann -prettyAlt (Abs b body) = prettyBinderNoAnn b <+> "->" <> nest 2 (p body) - -prettyBinderNoAnn :: Binder r n l -> Doc ann -prettyBinderNoAnn (b:>_) = p b - -instance (IRRep r, PrettyPrecE e) => Pretty (Abs (Binder r) e n) where pretty = prettyFromPrettyPrec -instance (IRRep r, PrettyPrecE e) => PrettyPrec (Abs (Binder r) e n) where - prettyPrec (Abs binder body) = atPrec LowestPrec $ "\\" <> p binder <> "." <> pLowest body - -instance IRRep r => Pretty (DeclBinding r n) where - pretty (DeclBinding ann expr) = "Decl" <> p ann <+> p expr - -instance IRRep r => Pretty (Decl r n l) where - pretty (Let b (DeclBinding ann rhs)) = - align $ annDoc <> p (b:>getType rhs) <+> "=" <> (nest 2 $ group $ line <> pLowest rhs) - where annDoc = case ann of NoInlineLet -> pretty ann <> " "; _ -> pretty ann - -instance IRRep r => Pretty (PiType r n) where - pretty (PiType bs (EffTy effs resultTy)) = - (spaced $ fromNest $ bs) <+> "->" <+> "{" <> p effs <> "}" <+> p resultTy - -instance IRRep r => Pretty (LamExpr r n) where pretty = prettyFromPrettyPrec -instance IRRep r => PrettyPrec (LamExpr r n) where - prettyPrec (LamExpr bs body) = - atPrec LowestPrec $ prettyLam (p bs <> ".") body - -instance IRRep r => Pretty (IxType r n) where - pretty (IxType ty dict) = parens $ "IxType" <+> pretty ty <> prettyIxDict dict - -instance Pretty (DictExpr n) where - pretty d = case d of - InstanceDict name args -> "Instance" <+> p name <+> p args - InstantiatedGiven v args -> "Given" <+> p v <+> p (toList args) - SuperclassProj d' i -> "SuperclassProj" <+> p d' <+> p i - IxFin n -> "Ix (Fin" <+> p n <> ")" - DataData a -> "Data " <+> p a - -instance IRRep r => Pretty (IxDict r n) where - pretty = \case - IxDictAtom x -> p x - IxDictRawFin n -> "Ix (RawFin " <> p n <> ")" - IxDictSpecialized _ d xs -> p d <+> p xs - -instance Pretty (DictType n) where - pretty (DictType classSourceName _ params) = - p classSourceName <+> spaced params - -instance IRRep r => Pretty (DepPairType r n) where pretty = prettyFromPrettyPrec -instance IRRep r => PrettyPrec (DepPairType r n) where - prettyPrec (DepPairType _ b rhs) = - atPrec ArgPrec $ align $ group $ parensSep (spaceIfColinear <> "&> ") [p b, p rhs] - -instance Pretty (EffectOpType n) where - pretty (EffectOpType pol ty) = "[" <+> p pol <+> ":" <+> p ty <+> "]" - -instance Pretty (CoreLamExpr n) where - pretty (CoreLamExpr _ lam) = p lam - -instance IRRep r => Pretty (Atom r n) where pretty = prettyFromPrettyPrec -instance IRRep r => PrettyPrec (Atom r n) where - prettyPrec atom = case atom of - Var v -> atPrec ArgPrec $ p v - Lam lam -> atPrec LowestPrec $ p lam - DepPair x y _ -> atPrec ArgPrec $ align $ group $ - parens $ p x <+> ",>" <+> p y - Con e -> prettyPrec e - Eff e -> atPrec ArgPrec $ p e - PtrVar _ v -> atPrec ArgPrec $ p v - DictCon _ d -> atPrec LowestPrec $ p d - RepValAtom x -> atPrec LowestPrec $ pretty x - ProjectElt _ idxs v -> atPrec LowestPrec $ "ProjectElt" <+> p idxs <+> p v - NewtypeCon con x -> prettyPrecNewtype con x - SimpInCore x -> prettyPrec x - DictHole _ e _ -> atPrec LowestPrec $ "synthesize" <+> pApp e - TypeAsAtom ty -> prettyPrec ty - -instance IRRep r => Pretty (Type r n) where pretty = prettyFromPrettyPrec -instance IRRep r => PrettyPrec (Type r n) where - prettyPrec = \case - Pi piType -> atPrec LowestPrec $ align $ p piType - TabPi piType -> atPrec LowestPrec $ align $ p piType - DepPairTy ty -> prettyPrec ty - TC e -> prettyPrec e - DictTy t -> atPrec LowestPrec $ p t - NewtypeTyCon con -> prettyPrec con - TyVar v -> atPrec ArgPrec $ p v - ProjectEltTy _ idxs v -> - atPrec LowestPrec $ "ProjectElt" <+> p idxs <+> p v - -instance Pretty (SimpInCore n) where pretty = prettyFromPrettyPrec -instance PrettyPrec (SimpInCore n) where - prettyPrec = \case - LiftSimp ty x -> atPrec ArgPrec $ " p x <+> " : " <+> p ty <+> ">" - LiftSimpFun ty x -> atPrec ArgPrec $ " p x <+> " : " <+> p ty <+> ">" - ACase e alts _ -> atPrec AppPrec $ "acase" <+> p e <+> p alts - TabLam _ _ -> atPrec AppPrec $ "tablam" - -instance IRRep r => Pretty (RepVal r n) where - pretty (RepVal ty tree) = " p tree <+> ":" <+> p ty <> ">" - -instance Pretty a => Pretty (Tree a) where - pretty = \case - Leaf x -> pretty x - Branch xs -> pretty xs - -instance Pretty Projection where - pretty = \case - UnwrapNewtype -> "u" - ProjectProduct i -> p i - -forStr :: ForAnn -> Doc ann -forStr Fwd = "for" -forStr Rev = "rof" - -instance Pretty (CorePiType n) where - pretty (CorePiType appExpl expls bs (EffTy eff resultTy)) = - prettyBindersWithExpl expls bs <+> p appExpl <> prettyEff <> p resultTy - where - prettyEff = case eff of - Pure -> space - _ -> space <> pretty eff <> space - -prettyBindersWithExpl :: forall b n l ann. PrettyB b - => [Explicitness] -> Nest b n l -> Doc ann -prettyBindersWithExpl expls bs = do - let groups = groupByExpl $ zip expls (fromNest bs) - let groups' = case groups of [] -> [(Explicit, [])] - _ -> groups - mconcat [withExplParens expl $ commaSep bsGroup | (expl, bsGroup) <- groups'] - -groupByExpl :: [(Explicitness, b UnsafeS UnsafeS)] -> [(Explicitness, [b UnsafeS UnsafeS])] -groupByExpl [] = [] -groupByExpl ((expl, b):bs) = do - let (matches, rest) = span (\(expl', _) -> expl == expl') bs - let matches' = map snd matches - (expl, b:matches') : groupByExpl rest - -withExplParens :: Explicitness -> Doc ann -> Doc ann -withExplParens Explicit x = parens x -withExplParens (Inferred _ Unify) x = braces $ x -withExplParens (Inferred _ (Synth _)) x = brackets x - -instance IRRep r => Pretty (TabPiType r n) where - pretty (TabPiType dict (b:>ty) body) = let - prettyBody = case body of - Pi subpi -> pretty subpi - _ -> pLowest body - prettyBinder = case dict of - IxDictRawFin n -> if binderName b `isFreeIn` body - then parens $ p b <> ":" <> prettyTy - else prettyTy - where prettyTy = "RawFin" <+> p n - _ -> prettyBinderHelper (b:>ty) body - in prettyBinder <> prettyIxDict dict <> (group $ line <> "=>" <+> prettyBody) - --- A helper to let us turn dict printing on and off. We mostly want it off to --- reduce clutter in prints and error messages, but when debugging synthesis we --- want it on. -prettyIxDict :: IRRep r => IxDict r n -> Doc ann -prettyIxDict dict = if False then " " <> p dict else mempty - -prettyBinderHelper :: IRRep r => HoistableE e => Binder r n l -> e l -> Doc ann -prettyBinderHelper (b:>ty) body = - if binderName b `isFreeIn` body - then parens $ p (b:>ty) - else p ty - -prettyLam :: Pretty a => Doc ann -> a -> Doc ann -prettyLam binders body = - group $ group (nest 4 $ binders) <> group (nest 2 $ p body) - -instance IRRep r => Pretty (EffectRow r n) where - pretty (EffectRow effs t) = - braces $ hsep (punctuate "," (map p (eSetToList effs))) <> p t - -instance IRRep r => Pretty (EffectRowTail r n) where - pretty = \case - NoTail -> mempty - EffectRowTail v -> "|" <> p v - -instance IRRep r => Pretty (Effect r n) where - pretty eff = case eff of - RWSEffect rws h -> p rws <+> p h - ExceptionEffect -> "Except" - IOEffect -> "IO" - InitEffect -> "Init" - -instance Pretty (UEffect n) where - pretty eff = case eff of - URWSEffect rws h -> p rws <+> p h - UExceptionEffect -> "Except" - UIOEffect -> "IO" - -instance PrettyPrec (Name s n) where prettyPrec = atPrec ArgPrec . pretty - -instance PrettyPrec (AtomVar r n) where - prettyPrec (AtomVar v _) = prettyPrec v -instance Pretty (AtomVar r n) where pretty = prettyFromPrettyPrec - -instance IRRep r => Pretty (AtomBinding r n) where - pretty binding = case binding of - LetBound b -> p b - MiscBound t -> p t - SolverBound b -> p b - FFIFunBound s _ -> p s - NoinlineFun ty _ -> "Top function with type: " <+> p ty - TopDataBound (RepVal ty _) -> "Top data with type: " <+> p ty - -instance Pretty (SpecializationSpec n) where - pretty (AppSpecialization f (Abs bs (ListE args))) = - "Specialization" <+> p f <+> p bs <+> p args - -instance Pretty IxMethod where - pretty method = p $ show method - -instance Pretty (SolverBinding n) where - pretty (InfVarBound ty _) = "Inference variable of type:" <+> p ty - pretty (SkolemBound ty ) = "Skolem variable of type:" <+> p ty - -instance Pretty (Binding c n) where - pretty b = case b of - -- using `unsafeCoerceIRE` here because otherwise we don't have `IRRep` - -- TODO: can we avoid printing needing IRRep? Presumably it's related to - -- manipulating sets or something, which relies on Eq/Ord, which relies on renaming. - AtomNameBinding info -> "Atom name:" <+> pretty (unsafeCoerceIRE @CoreIR info) - TyConBinding dataDef _ -> "Type constructor: " <+> pretty dataDef - DataConBinding tyConName idx -> "Data constructor:" <+> - pretty tyConName <+> "Constructor index:" <+> pretty idx - ClassBinding classDef -> pretty classDef - InstanceBinding instanceDef _ -> pretty instanceDef - MethodBinding className idx -> "Method" <+> pretty idx <+> "of" <+> pretty className - TopFunBinding f -> pretty f - FunObjCodeBinding _ -> "" - ModuleBinding _ -> "" - PtrBinding _ _ -> "" - SpecializedDictBinding _ -> "" - ImpNameBinding ty -> "Imp name of type: " <+> p ty - -instance Pretty (Module n) where - pretty m = prettyRecord - [ ("moduleSourceName" , p $ moduleSourceName m) - , ("moduleDirectDeps" , p $ S.toList $ moduleDirectDeps m) - , ("moduleTransDeps" , p $ S.toList $ moduleTransDeps m) - , ("moduleExports" , p $ moduleExports m) - , ("moduleSynthCandidates", p $ moduleSynthCandidates m) ] - -instance Pretty (TyConParams n) where - pretty (TyConParams _ _) = undefined - -instance Pretty (TyConDef n) where - pretty (TyConDef name _ bs cons) = "data" <+> p name <+> p bs <> pretty cons - -instance Pretty (DataConDefs n) where - pretty = undefined - -instance Pretty (DataConDef n) where - pretty (DataConDef name _ repTy _) = - p name <+> ":" <+> p repTy - -instance Pretty (ClassDef n) where - pretty (ClassDef classSourceName methodNames _ _ params superclasses methodTys) = - "Class:" <+> pretty classSourceName <+> pretty methodNames - <> indented ( - line <> "parameter binders:" <+> pretty params <> - line <> "superclasses:" <+> pretty superclasses <> - line <> "methods:" <+> pretty methodTys) - -instance Pretty ParamRole where - pretty r = p (show r) - -instance Pretty (InstanceDef n) where - pretty (InstanceDef className _ bs params _) = - "Instance" <+> p className <+> pretty bs <+> p params - -deriving instance (forall c n. Pretty (v c n)) => Pretty (RecSubst v o) - -instance Pretty (TopEnv n) where - pretty (TopEnv defs rules cache _ _) = - prettyRecord [ ("Defs" , p defs) - , ("Rules" , p rules) - , ("Cache" , p cache) ] - -instance Pretty (CustomRules n) where - pretty _ = "TODO: Rule printing" - -instance Pretty (ImportStatus n) where - pretty imports = pretty $ S.toList $ directImports imports - -instance Pretty (ModuleEnv n) where - pretty (ModuleEnv imports sm sc) = - prettyRecord [ ("Imports" , p imports) - , ("Source map" , p sm) - , ("Synth candidates", p sc) ] - -instance Pretty (Env n) where - pretty (Env env1 env2) = - prettyRecord [ ("Top env" , p env1) - , ("Module env", p env2)] - -prettyRecord :: [(String, Doc ann)] -> Doc ann -prettyRecord xs = foldMap (\(name, val) -> pretty name <> indented val) xs - -instance Pretty SourceBlock where - pretty block = pretty $ ensureNewline (sbText block) where - -- Force the SourceBlock to end in a newline for echoing, even if - -- it was terminated with EOF in the original program. - ensureNewline t = case unsnoc t of - Nothing -> t - Just (_, '\n') -> t - _ -> t `snoc` '\n' - -prettyDuration :: Double -> Doc ann -prettyDuration d = p (showFFloat (Just 3) (d * mult) "") <+> unit - where (mult, unit) = if d >= 1 then (1 , "s") - else if d >= 1e-3 then (1e3, "ms") - else if d >= 1e-6 then (1e6, "us") - else (1e9, "ns") - -instance Pretty Output where - pretty (TextOut s) = pretty s - pretty (HtmlOut _) = "" - -- pretty (ExportedFun _ _) = "" - pretty (BenchResult name compileTime runTime stats) = - benchName <> hardline <> - "Compile time: " <> prettyDuration compileTime <> hardline <> - "Run time: " <> prettyDuration runTime <+> - (case stats of - Just (runs, _) -> - "\t" <> parens ("based on" <+> p runs <+> plural "run" "runs" runs) - Nothing -> "") - where benchName = case name of "" -> "" - _ -> "\n" <> p name - pretty (PassInfo _ s) = p s - pretty (EvalTime t _) = "Eval (s): " <+> p t - pretty (TotalTime t) = "Total (s): " <+> p t <+> " (eval + compile)" - pretty (MiscLog s) = p s - - -instance Pretty PassName where - pretty x = p $ show x - -instance Pretty Result where - pretty (Result outs r) = vcat (map pretty outs) <> maybeErr - where maybeErr = case r of Failure err -> p err - Success () -> mempty - -instance Pretty (UBinder c n l) where pretty = prettyFromPrettyPrec -instance PrettyPrec (UBinder c n l) where - prettyPrec b = atPrec ArgPrec case b of - UBindSource _ v -> p v - UIgnore -> "_" - UBind _ v _ -> p v - -instance PrettyE e => Pretty (WithSrcE e n) where - pretty (WithSrcE _ x) = p x - -instance PrettyPrecE e => PrettyPrec (WithSrcE e n) where - prettyPrec (WithSrcE _ x) = prettyPrec x - -instance PrettyB b => Pretty (WithSrcB b n l) where - pretty (WithSrcB _ x) = p x - -instance PrettyPrecB b => PrettyPrec (WithSrcB b n l) where - prettyPrec (WithSrcB _ x) = prettyPrec x - -instance PrettyE e => Pretty (SourceNameOr e n) where - pretty (SourceName _ v) = p v - pretty (InternalName _ v _) = p v - -instance Pretty (SourceOrInternalName c n) where - pretty (SourceOrInternalName sn) = p sn - -instance Pretty (ULamExpr n) where pretty = prettyFromPrettyPrec -instance PrettyPrec (ULamExpr n) where - prettyPrec (ULamExpr bs _ _ _ body) = atPrec LowestPrec $ - "\\" <> p bs <+> "." <+> indented (p body) - -instance Pretty (UPiExpr n) where pretty = prettyFromPrettyPrec -instance PrettyPrec (UPiExpr n) where - prettyPrec (UPiExpr pats appExpl UPure ty) = atPrec LowestPrec $ align $ - p pats <+> p appExpl <+> pLowest ty - prettyPrec (UPiExpr pats appExpl eff ty) = atPrec LowestPrec $ align $ - p pats <+> p appExpl <+> p eff <+> pLowest ty - -instance Pretty Explicitness where - pretty expl = p (show expl) - -instance Pretty (UTabPiExpr n) where pretty = prettyFromPrettyPrec -instance PrettyPrec (UTabPiExpr n) where - prettyPrec (UTabPiExpr pat ty) = atPrec LowestPrec $ align $ - p pat <+> "=>" <+> pLowest ty - -instance Pretty (UDepPairType n) where pretty = prettyFromPrettyPrec -instance PrettyPrec (UDepPairType n) where - -- TODO: print explicitness info - prettyPrec (UDepPairType _ pat ty) = atPrec LowestPrec $ align $ - p pat <+> "&>" <+> pLowest ty - -instance Pretty (UBlock' n) where - pretty (UBlock decls result) = - prettyLines (fromNest decls) <> hardline <> pLowest result - -instance Pretty (UExpr' n) where pretty = prettyFromPrettyPrec -instance PrettyPrec (UExpr' n) where - prettyPrec expr = case expr of - ULit l -> prettyPrec l - UVar v -> atPrec ArgPrec $ p v - ULam lam -> prettyPrec lam - UApp f xs named -> atPrec AppPrec $ pAppArg (pApp f) xs <+> p named - UTabApp f x -> atPrec AppPrec $ pArg f <> "." <> pArg x - UFor dir (UForExpr binder body) -> - atPrec LowestPrec $ kw <+> p binder <> "." - <+> nest 2 (p body) - where kw = case dir of Fwd -> "for" - Rev -> "rof" - UPi piType -> prettyPrec piType - UTabPi piType -> prettyPrec piType - UDepPairTy depPairType -> prettyPrec depPairType - UDepPair lhs rhs -> atPrec ArgPrec $ parens $ - p lhs <+> ",>" <+> p rhs - UHole -> atPrec ArgPrec "_" - UTypeAnn v ty -> atPrec LowestPrec $ - group $ pApp v <> line <> ":" <+> pApp ty - UTabCon xs -> atPrec ArgPrec $ p xs - UPrim prim xs -> atPrec AppPrec $ p (show prim) <+> p xs - UCase e alts -> atPrec LowestPrec $ "case" <+> p e <> - nest 2 (prettyLines alts) - UFieldAccess x (WithSrc _ f) -> atPrec AppPrec $ p x <> "~" <> p f - UNatLit v -> atPrec ArgPrec $ p v - UIntLit v -> atPrec ArgPrec $ p v - UFloatLit v -> atPrec ArgPrec $ p v - UDo block -> atPrec LowestPrec $ p block - -instance Pretty FieldName' where - pretty = \case - FieldName s -> pretty s - FieldNum n -> pretty n - -instance Pretty (UAlt n) where - pretty (UAlt pat body) = p pat <+> "->" <+> p body - -instance Pretty (UTopDecl n l) where - pretty (UDataDefDecl (UDataDef nm (_, bs) dataCons) bTyCon bDataCons) = - "data" <+> p bTyCon <+> p nm <+> spaced (fromNest bs) <+> "where" <> nest 2 - (prettyLines (zip (toList $ fromNest bDataCons) dataCons)) - pretty (UStructDecl bTyCon (UStructDef nm (_, bs) fields defs)) = - "struct" <+> p bTyCon <+> p nm <+> spaced (fromNest bs) <+> "where" <> nest 2 - (prettyLines fields <> prettyLines defs) - pretty (UInterface params methodTys interfaceName methodNames) = - "interface" <+> p params <+> p interfaceName - <> hardline <> foldMap (<>hardline) methods - where - methods = [ p b <> ":" <> p (unsafeCoerceE ty) - | (b, ty) <- zip (toList $ fromNest methodNames) methodTys] - pretty (UInstance className bs params methods (RightB UnitB) _) = - "instance" <+> p bs <+> p className <+> spaced params <+> - prettyLines methods - pretty (UInstance className bs params methods (LeftB v) _) = - "named-instance" <+> p v <+> ":" <+> p bs <+> p className <+> p params - <> prettyLines methods - pretty (UEffectDecl opTys effName opNames) = - "effect" <+> p effName <> hardline <> foldMap (<>hardline) ops - where ops = [ p pol <+> p b <> ":" <> p (unsafeCoerceE ty) - | (b, UEffectOpType pol ty) <- zip (toList $ fromNest opNames) opTys] - pretty (UHandlerDecl effName bodyTyArg tyArgs retEff retTy opDefs name) = - "handler" <+> p name <+> "of" <+> p effName <+> p bodyTyArg <+> p tyArgs - <+> ":" <+> p retEff <+> p retTy <> hardline - <> foldMap ((<>hardline) . p) opDefs - pretty (ULocalDecl decl) = p decl - -instance Pretty (UDecl' n l) where - pretty (ULet ann b _ rhs) = - align $ p ann <+> p b <+> "=" <> (nest 2 $ group $ line <> pLowest rhs) - pretty (UExprDecl expr) = p expr - pretty UPass = "pass" - -instance Pretty (UEffectOpDef n) where - pretty (UEffectOpDef rp n body) = p rp <+> p n <+> "=" <+> p body - pretty (UReturnOpDef body) = "return =" <+> p body - -instance Pretty UResumePolicy where - pretty UNoResume = "jmp" - pretty ULinearResume = "def" - pretty UAnyResume = "ctl" - -instance Pretty (UEffectRow n) where - pretty (UEffectRow x Nothing) = encloseSep "<" ">" "," $ (p <$> toList x) - pretty (UEffectRow x (Just y)) = "{" <> (hsep $ punctuate "," (p <$> toList x)) <+> "|" <+> p y <> "}" - -prettyBinderNest :: PrettyB b => Nest b n l -> Doc ann -prettyBinderNest bs = nest 6 $ line' <> (sep $ map p $ fromNest bs) - -instance Pretty (UDataDefTrail n) where - pretty (UDataDefTrail bs) = p $ fromNest bs - -instance Pretty (UAnnBinder req n l) where - pretty (UAnnBinder b ty cs) = p b <> ":" <> p ty <> printConstraints cs - -printConstraints :: Pretty a => [a] -> Doc ann -printConstraints = \case - [] -> mempty - c:cs -> "|" <> pretty c <> printConstraints cs - -instance Pretty (UAnn req n) where - pretty (UAnn ty) = ":" <> p ty - pretty UNoAnn = mempty - -instance Pretty (UMethodDef' n) where - pretty (UMethodDef b rhs) = p b <+> "=" <+> p rhs - -instance Pretty (UPat' n l) where pretty = prettyFromPrettyPrec -instance PrettyPrec (UPat' n l) where - prettyPrec pat = case pat of - UPatBinder x -> atPrec ArgPrec $ p x - UPatProd xs -> atPrec ArgPrec $ parens $ commaSep (fromNest xs) - UPatDepPair (PairB x y) -> atPrec ArgPrec $ parens $ p x <> ",> " <> p y - UPatCon con pats -> atPrec AppPrec $ parens $ p con <+> spaced (fromNest pats) - UPatTable pats -> atPrec ArgPrec $ p pats +instance PrettyPrec () where prettyPrec = atPrec ArgPrec . pretty spaced :: (Foldable f, Pretty a) => f a -> Doc ann -spaced xs = hsep $ map p $ toList xs - -dotted :: (Foldable f, Pretty a) => f a -> Doc ann -dotted xs = fold $ punctuate "." $ map p $ toList xs +spaced xs = hsep $ map pretty $ toList xs commaSep :: (Foldable f, Pretty a) => f a -> Doc ann -commaSep xs = fold $ punctuate "," $ map p $ toList xs - -instance Pretty (EnvFrag n l) where - pretty (EnvFrag bindings) = p bindings - -instance Pretty (Cache n) where - pretty (Cache _ _ _ _ _ _) = "" -- TODO - -instance Pretty (SynthCandidates n) where - pretty scs = - "lambda dicts:" <+> p (lambdaDicts scs) <> hardline - <> "instance dicts:" <+> p (M.toList $ instanceDicts scs) - -instance Pretty (LoadedModules n) where - pretty _ = "" +commaSep xs = fold $ punctuate "," $ map pretty $ toList xs indented :: Doc ann -> Doc ann indented doc = nest 2 (hardline <> doc) <> hardline - --- ==== Imp IR === - -instance Pretty (IExpr n) where - pretty (ILit v) = p v - pretty (IVar v _) = p v - pretty (IPtrVar v _) = p v - -instance PrettyPrec (IExpr n) where prettyPrec = atPrec ArgPrec . pretty - -instance Pretty (ImpDecl n l) where - pretty (ImpLet Empty instr) = p instr - pretty (ImpLet (Nest b Empty) instr) = p b <+> "=" <+> p instr - pretty (ImpLet bs instr) = p bs <+> "=" <+> p instr - -instance Pretty IFunType where - pretty (IFunType cc argTys retTys) = - "Fun" <+> p cc <+> p argTys <+> "->" <+> p retTys - -instance Pretty (TopFunDef n) where - pretty = \case - Specialization s -> p s - LinearizationPrimal _ -> "" - LinearizationTangent _ -> "" - -instance Pretty (TopFun n) where - pretty = \case - DexTopFun def lam lowering -> - "Top-level Function" - <> hardline <+> "definition:" <+> pretty def - <> hardline <+> "lambda:" <+> pretty lam - <> hardline <+> "lowering:" <+> pretty lowering - FFITopFun f _ -> p f - -instance IRRep r => Pretty (TopLam r n) where - pretty (TopLam _ _ lam) = pretty lam - -instance Pretty a => Pretty (EvalStatus a) where - pretty = \case - Waiting -> "" - Running -> "" - Finished a -> pretty a - -instance Pretty (ImpFunction n) where - pretty (ImpFunction (IFunType cc _ _) (Abs bs body)) = - "impfun" <+> p cc <+> prettyBinderNest bs - <> nest 2 (hardline <> p body) <> hardline - -instance Pretty (ImpBlock n) where - pretty (ImpBlock Empty []) = mempty - pretty (ImpBlock Empty expr) = group $ line <> pLowest expr - pretty (ImpBlock decls []) = prettyLines $ fromNest decls - pretty (ImpBlock decls expr) = prettyLines decls' <> hardline <> pLowest expr - where decls' = fromNest decls - -instance Pretty (IBinder n l) where - pretty (IBinder b ty) = p b <+> ":" <+> p ty - -instance Pretty (ImpInstr n) where - pretty = \case - IFor a n (Abs i block) -> forStr a <+> p i <+> "<" <+> p n <> - nest 4 (p block) - IWhile body -> "while" <+> nest 2 (p body) - ICond predicate cons alt -> - "if" <+> p predicate <+> "then" <> nest 2 (p cons) <> - hardline <> "else" <> nest 2 (p alt) - IQueryParallelism f s -> "queryParallelism" <+> p f <+> p s - ILaunch f size args -> - "launch" <+> p f <+> p size <+> spaced args - ICastOp t x -> "cast" <+> p x <+> "to" <+> p t - IBitcastOp t x -> "bitcast" <+> p x <+> "to" <+> p t - Store dest val -> "store" <+> p dest <+> p val - Alloc _ t s -> "alloc" <+> p t <> "[" <> sizeStr s <> "]" - StackAlloc t s -> "alloca" <+> p t <> "[" <> sizeStr s <> "]" - MemCopy dest src numel -> "memcopy" <+> p dest <+> p src <+> p numel - InitializeZeros ptr numel -> "initializeZeros" <+> p ptr <+> p numel - GetAllocSize ptr -> "getAllocSize" <+> p ptr - Free ptr -> "free" <+> p ptr - ISyncWorkgroup -> "syncWorkgroup" - IThrowError -> "throwError" - ICall f args -> "call" <+> p f <+> p args - IVectorBroadcast v _ -> "vbroadcast" <+> p v - IVectorIota _ -> "viota" - DebugPrint s x -> "debug_print" <+> p (show s) <+> p x - IPtrLoad ptr -> "load" <+> p ptr - IPtrOffset ptr idx -> p ptr <+> "+>" <+> p idx - IBinOp op x y -> opDefault (UBinOp op) [x, y] - IUnOp op x -> opDefault (UUnOp op) [x] - ISelect x y z -> "select" <+> p x <+> p y <+> p z - IOutputStream -> "outputStream" - IShowScalar ptr x -> "show_scalar" <+> p ptr <+> p x - where opDefault name xs = prettyOpDefault name xs $ AppPrec - -sizeStr :: IExpr n -> Doc ann -sizeStr s = case s of - ILit (Word32Lit x) -> p x -- print in decimal because it's more readable - _ -> p s - -instance Pretty BaseType where pretty = prettyFromPrettyPrec -instance PrettyPrec BaseType where - prettyPrec b = case b of - Scalar sb -> prettyPrec sb - Vector shape sb -> atPrec ArgPrec $ encloseSep "<" ">" "x" $ (p <$> shape) ++ [p sb] - PtrType ty -> atPrec AppPrec $ "Ptr" <+> p ty - -instance Pretty AddressSpace where pretty d = p (show d) - -instance Pretty ScalarBaseType where pretty = prettyFromPrettyPrec -instance PrettyPrec ScalarBaseType where - prettyPrec sb = atPrec ArgPrec $ case sb of - Int64Type -> "Int64" - Int32Type -> "Int32" - Float64Type -> "Float64" - Float32Type -> "Float32" - Word8Type -> "Word8" - Word32Type -> "Word32" - Word64Type -> "Word64" - -instance IRRep r => Pretty (TC r n) where pretty = prettyFromPrettyPrec -instance IRRep r => PrettyPrec (TC r n) where - prettyPrec con = case con of - BaseType b -> prettyPrec b - ProdType [] -> atPrec ArgPrec $ "()" - ProdType as -> atPrec ArgPrec $ align $ group $ - encloseSep "(" ")" ", " $ fmap pApp as - SumType cs -> atPrec ArgPrec $ align $ group $ - encloseSep "(|" "|)" " | " $ fmap pApp cs - RefType h a -> atPrec AppPrec $ pAppArg "Ref" [h] <+> p a - TypeKind -> atPrec ArgPrec "Type" - HeapType -> atPrec ArgPrec "Heap" - -prettyPrecNewtype :: NewtypeCon n -> CAtom n -> DocPrec ann -prettyPrecNewtype con x = case (con, x) of - (NatCon, (IdxRepVal n)) -> atPrec ArgPrec $ pretty n - (_, x') -> prettyPrec x' - -instance Pretty (NewtypeTyCon n) where pretty = prettyFromPrettyPrec -instance PrettyPrec (NewtypeTyCon n) where - prettyPrec = \case - Nat -> atPrec ArgPrec $ "Nat" - Fin n -> atPrec AppPrec $ "Fin" <+> pArg n - EffectRowKind -> atPrec ArgPrec "EffKind" - UserADTType "RangeTo" _ (TyConParams _ [i]) -> atPrec LowestPrec $ ".." <> pApp i - UserADTType "RangeToExc" _ (TyConParams _ [i]) -> atPrec LowestPrec $ "..<" <> pApp i - UserADTType "RangeFrom" _ (TyConParams _ [i]) -> atPrec LowestPrec $ pApp i <> ".." - UserADTType "RangeFromExc" _ (TyConParams _ [i]) -> atPrec LowestPrec $ pApp i <> "<.." - UserADTType name _ (TyConParams infs params) -> case (infs, params) of - ([], []) -> atPrec ArgPrec $ p name - ([Explicit, Explicit], [l, r]) - | Just sym <- fromInfix (fromString name) -> - atPrec ArgPrec $ align $ group $ - parens $ flatAlt " " "" <> pApp l <> line <> p sym <+> pApp r - _ -> atPrec LowestPrec $ pAppArg (p name) $ ignoreSynthParams (TyConParams infs params) - -instance IRRep r => Pretty (Con r n) where pretty = prettyFromPrettyPrec -instance IRRep r => PrettyPrec (Con r n) where - prettyPrec = \case - Lit l -> prettyPrec l - ProdCon [x] -> atPrec ArgPrec $ "(" <> pLowest x <> ",)" - ProdCon xs -> atPrec ArgPrec $ align $ group $ - encloseSep "(" ")" ", " $ fmap pLowest xs - SumCon _ tag payload -> atPrec ArgPrec $ - "(" <> p tag <> "|" <+> pApp payload <+> "|)" - HeapVal -> atPrec ArgPrec "HeapValue" - -instance IRRep r => Pretty (PrimOp r n) where pretty = prettyFromPrettyPrec -instance IRRep r => PrettyPrec (PrimOp r n) where - prettyPrec = \case - MemOp op -> prettyPrec op - VectorOp op -> prettyPrec op - DAMOp op -> prettyPrec op - Hof (TypedHof _ hof) -> prettyPrec hof - RefOp ref eff -> atPrec LowestPrec case eff of - MAsk -> "ask" <+> pApp ref - MExtend _ x -> "extend" <+> pApp ref <+> pApp x - MGet -> "get" <+> pApp ref - MPut x -> pApp ref <+> ":=" <+> pApp x - IndexRef _ i -> pApp ref <+> "!" <+> pApp i - ProjRef _ i -> "proj_ref" <+> pApp ref <+> p i - UnOp op x -> prettyOpDefault (UUnOp op) [x] - BinOp op x y -> prettyOpDefault (UBinOp op) [x, y] - MiscOp op -> prettyOpGeneric op - -instance IRRep r => Pretty (MemOp r n) where pretty = prettyFromPrettyPrec -instance IRRep r => PrettyPrec (MemOp r n) where - prettyPrec = \case - PtrOffset ptr idx -> atPrec LowestPrec $ pApp ptr <+> "+>" <+> pApp idx - PtrLoad ptr -> atPrec AppPrec $ pAppArg "load" [ptr] - op -> prettyOpGeneric op - -instance IRRep r => Pretty (VectorOp r n) where pretty = prettyFromPrettyPrec -instance IRRep r => PrettyPrec (VectorOp r n) where - prettyPrec = \case - VectorBroadcast v vty -> atPrec LowestPrec $ "vbroadcast" <+> pApp v <+> pApp vty - VectorIota vty -> atPrec LowestPrec $ "viota" <+> pApp vty - VectorIdx tbl i vty -> atPrec LowestPrec $ "vslice" <+> pApp tbl <+> pApp i <+> pApp vty - VectorSubref ref i _ -> atPrec LowestPrec $ "vrefslice" <+> pApp ref <+> pApp i - -prettyOpDefault :: PrettyPrec a => PrimName -> [a] -> DocPrec ann -prettyOpDefault name args = - case length args of - 0 -> atPrec ArgPrec primName - _ -> atPrec AppPrec $ pAppArg primName args - where primName = p name - -prettyOpGeneric :: (IRRep r, GenericOp op, Show (OpConst op r)) => op r n -> DocPrec ann -prettyOpGeneric op = case fromEGenericOpRep op of - GenericOpRep op' [] [] [] -> atPrec ArgPrec (p $ show op') - GenericOpRep op' ts xs lams -> atPrec AppPrec $ pAppArg (p (show op')) xs <+> p ts <+> p lams - -instance Pretty PrimName where - pretty primName = p $ "%" ++ showPrimName primName - -instance IRRep r => Pretty (Hof r n) where pretty = prettyFromPrettyPrec -instance IRRep r => PrettyPrec (Hof r n) where - prettyPrec hof = atPrec LowestPrec case hof of - For _ _ lam -> "for" <+> pLowest lam - While body -> "while" <+> pArg body - RunReader x body -> "runReader" <+> pArg x <> nest 2 (line <> p body) - RunWriter _ bm body -> "runWriter" <+> pArg bm <> nest 2 (line <> p body) - RunState _ x body -> "runState" <+> pArg x <> nest 2 (line <> p body) - RunIO body -> "runIO" <+> pArg body - RunInit body -> "runInit" <+> pArg body - CatchException _ body -> "catchException" <+> pArg body - Linearize body x -> "linearize" <+> pArg body <+> pArg x - Transpose body x -> "transpose" <+> pArg body <+> pArg x - -instance IRRep r => Pretty (DAMOp r n) where pretty = prettyFromPrettyPrec -instance IRRep r => PrettyPrec (DAMOp r n) where - prettyPrec op = atPrec LowestPrec case op of - Seq _ ann d c lamExpr -> case lamExpr of - UnaryLamExpr b body -> do - let rawFinPretty = case d of - IxType _ (IxDictRawFin n) -> parens $ "RawFin" <+> p n - _ -> mempty - "seq" <+> rawFinPretty <+> pApp ann <+> pApp c <+> prettyLam (p b <> ".") body - _ -> p (show op) -- shouldn't happen, but crashing pretty printers make debugging hard - RememberDest _ x y -> "rememberDest" <+> pArg x <+> pArg y - Place r v -> pApp r <+> "r:=" <+> pApp v - Freeze r -> "freeze" <+> pApp r - AllocDest ty -> "alloc" <+> pApp ty - -instance IRRep r => Pretty (BaseMonoid r n) where pretty = prettyFromPrettyPrec -instance IRRep r => PrettyPrec (BaseMonoid r n) where - prettyPrec (BaseMonoid x f) = - atPrec LowestPrec $ "baseMonoid" <+> pArg x <> nest 2 (line <> pArg f) - -instance PrettyPrec Direction where - prettyPrec d = atPrec ArgPrec $ case d of - Fwd -> "fwd" - Rev -> "rev" - -printDouble :: Double -> Doc ann -printDouble x = p (double2Float x) - -printFloat :: Float -> Doc ann -printFloat x = p $ reverse $ dropWhile (=='0') $ reverse $ - showFFloat (Just 6) x "" - -instance Pretty LitVal where pretty = prettyFromPrettyPrec -instance PrettyPrec LitVal where - prettyPrec (Int64Lit x) = atPrec ArgPrec $ p x - prettyPrec (Int32Lit x) = atPrec ArgPrec $ p x - prettyPrec (Float64Lit x) = atPrec ArgPrec $ printDouble x - prettyPrec (Float32Lit x) = atPrec ArgPrec $ printFloat x - prettyPrec (Word8Lit x) = atPrec ArgPrec $ p $ show $ toEnum @Char $ fromIntegral x - prettyPrec (Word32Lit x) = atPrec ArgPrec $ p $ "0x" ++ showHex x "" - prettyPrec (Word64Lit x) = atPrec ArgPrec $ p $ "0x" ++ showHex x "" - prettyPrec (PtrLit ty (PtrLitVal x)) = - atPrec ArgPrec $ "Ptr" <+> p ty <+> p (show x) - prettyPrec (PtrLit _ NullPtr) = atPrec ArgPrec $ "NullPtr" - prettyPrec (PtrLit _ (PtrSnapshot _)) = atPrec ArgPrec "" - -instance Pretty CallingConvention where - pretty = p . show - -instance Pretty LetAnn where - pretty ann = case ann of - PlainLet -> "" - NoInlineLet -> "%noinline" - OccInfoPure u -> p u <> line - OccInfoImpure u -> p u <> ", impure" <> line - -instance Pretty UsageInfo where - pretty (UsageInfo static (ixDepth, ct)) = - "occurs in" <+> p static <+> "places, read" - <+> p ct <+> "times, to depth" <+> p (show ixDepth) - -instance Pretty Count where - pretty (Bounded ct) = "<=" <+> pretty ct - pretty Occ.Unbounded = "many" - -instance PrettyPrec () where prettyPrec = atPrec ArgPrec . pretty - -instance Pretty RWS where - pretty eff = case eff of - Reader -> "Read" - Writer -> "Accum" - State -> "State" - -printLitBlock :: Pretty block => Bool -> block -> Result -> String -printLitBlock isatty block result = pprint block ++ printResult isatty result - -printResult :: Bool -> Result -> String -printResult isatty (Result outs errs) = - concat (map printOutput outs) ++ case errs of - Success () -> "" - Failure err -> addColor isatty Red $ addPrefix ">" $ pprint err - where - printOutput :: Output -> String - printOutput out = addPrefix (addColor isatty Cyan ">") $ pprint $ out - -addPrefix :: String -> String -> String -addPrefix prefix str = unlines $ map prefixLine $ lines str - where prefixLine :: String -> String - prefixLine s = case s of "" -> prefix - _ -> prefix ++ " " ++ s - -addColor :: Bool -> ANSI.Color -> String -> String -addColor False _ s = s -addColor True c s = - setSGRCode [SetConsoleIntensity BoldIntensity, SetColor Foreground Vivid c] - ++ s ++ setSGRCode [Reset] - -toJSONStr :: ToJSON a => a -> String -toJSONStr = B.unpack . encode - -instance ToJSON Result where - toJSON (Result outs err) = object (outMaps <> errMaps) - where - errMaps = case err of - Failure e -> ["error" .= String (fromString $ pprint e)] - Success () -> [] - outMaps = flip foldMap outs $ \case - BenchResult name compileTime runTime _ -> - [ "bench_name" .= toJSON name - , "compile_time" .= toJSON compileTime - , "run_time" .= toJSON runTime ] - out -> ["result" .= String (fromString $ pprint out)] - --- === Concrete syntax rendering === - -instance Pretty SourceBlock' where - pretty (TopDecl decl) = p decl - pretty d = fromString $ show d - -instance Pretty CTopDecl where - pretty (WithSrc _ d) = p d - -instance Pretty CTopDecl' where - pretty (CSDecl ann decl) = annDoc <> p decl - where annDoc = case ann of - PlainLet -> mempty - _ -> p ann <> " " - pretty d = fromString $ show d - -instance Pretty CSDecl where - pretty (WithSrc _ d) = p d - -instance Pretty CSDecl' where - pretty = undefined - -- pretty (CLet pat blk) = pArg pat <+> "=" <+> p blk - -- pretty (CBind pat blk) = pArg pat <+> "<-" <+> p blk - -- pretty (CDefDecl (CDef name args maybeAnn blk)) = - -- "def " <> fromString name <> " " <> prettyParamGroups args <+> annDoc - -- <> nest 2 (hardline <> p blk) - -- where annDoc = case maybeAnn of Just (expl, ty) -> p expl <+> pArg ty - -- Nothing -> mempty - -- pretty (CInstance header givens methods name) = - -- name' <> p header <> p givens <> nest 2 (hardline <> p methods) where - -- name' = case name of - -- Nothing -> "instance " - -- (Just n) -> "named-instance " <> p n <> " " - -- pretty (CExpr e) = p e - -instance Pretty AppExplicitness where - pretty ExplicitApp = "->" - pretty ImplicitApp = "->>" - -instance Pretty CSBlock where - pretty (IndentedBlock decls) = nest 2 $ prettyLines decls - pretty (ExprBlock g) = pArg g - -instance PrettyPrec Group where - prettyPrec (WithSrc _ g) = prettyPrec g - -instance Pretty Group where - pretty = prettyFromPrettyPrec - -instance PrettyPrec Group' where - prettyPrec (CIdentifier n) = atPrec ArgPrec $ fromString n - prettyPrec (CPrim prim args) = prettyOpDefault prim args - prettyPrec (CParens blk) = - atPrec ArgPrec $ "(" <> p blk <> ")" - prettyPrec (CBrackets g) = atPrec ArgPrec $ pretty g - prettyPrec (CBin (WithSrc _ JuxtaposeWithSpace) lhs rhs) = - atPrec AppPrec $ pApp lhs <+> pArg rhs - prettyPrec (CBin op lhs rhs) = - atPrec LowestPrec $ pArg lhs <+> p op <+> pArg rhs - prettyPrec (CLambda args body) = - atPrec LowestPrec $ "\\" <> spaced args <> "." <> p body - prettyPrec (CCase scrut alts) = - atPrec LowestPrec $ "case " <> p scrut <> " of " <> prettyLines alts - prettyPrec g = atPrec ArgPrec $ fromString $ show g - -instance Pretty Bin where - pretty (WithSrc _ b) = p b - -instance Pretty Bin' where - pretty (EvalBinOp name) = fromString name - pretty JuxtaposeWithSpace = " " - pretty JuxtaposeNoSpace = "" - pretty DepAmpersand = "&>" - pretty Dot = "." - pretty DepComma = ",>" - pretty Colon = ":" - pretty DoubleColon = "::" - pretty Dollar = "$" - pretty ImplicitArrow = "->>" - pretty FatArrow = "=>" - pretty Pipe = "|" - pretty CSEqual = "=" diff --git a/src/lib/PeepholeOptimize.hs b/src/lib/PeepholeOptimize.hs new file mode 100644 index 000000000..8ec599acd --- /dev/null +++ b/src/lib/PeepholeOptimize.hs @@ -0,0 +1,276 @@ +-- Copyright 2023 Google LLC +-- +-- Use of this source code is governed by a BSD-style +-- license that can be found in the LICENSE file or at +-- https://developers.google.com/open-source/licenses/bsd + +module PeepholeOptimize (PeepholeOpt (..), peepholeExpr) where + +import Data.Word +import Data.Bits +import Data.List +import Data.Bits.Floating +import GHC.Float + +import Types.Core +import Types.Primitives +import Name +import IRVariants +import qualified Types.OpNames as P + +peepholeExpr :: Expr r n -> Expr r n +peepholeExpr e = case peephole e of + Just x -> Atom x + Nothing -> e +{-# INLINE peepholeExpr #-} + +-- === Peephole optimization = undefined + +-- These are context-free (no env!) optimizations of expressions and ops that +-- are worth doing unconditionally. Builder calls this automatically in `emit`. + +class ToExpr e r => PeepholeOpt (e::E) (r::IR) | e -> r where + peephole :: e n -> Maybe (Atom r n) + +instance PeepholeOpt (Expr r) r where + peephole = \case + Atom x -> Just x + PrimOp op -> peephole op + Project _ i x -> case x of + Con con -> Just case con of + ProdCon xs -> xs !! i + DepPair l _ _ | i == 0 -> l + DepPair _ r _ | i == 1 -> r + _ -> error "not a product" + Stuck _ _ -> Nothing + Unwrap _ x -> case x of + Con con -> Just case con of + NewtypeCon _ x' -> x' + _ -> error "not a newtype" + Stuck _ _ -> Nothing + App _ _ _ -> Nothing + TabApp _ _ _ -> Nothing + Case _ _ _ -> Nothing + TopApp _ _ _ -> Nothing + Block _ _ -> Nothing + TabCon _ _ _ -> Nothing + ApplyMethod _ _ _ _ -> Nothing + {-# INLINE peephole #-} + +instance PeepholeOpt (PrimOp r) r where + peephole = \case + MiscOp op -> peephole op + BinOp op l r -> peepholeBinOp op l r + _ -> Nothing + {-# INLINE peephole #-} + +peepholeBinOp :: P.BinOp -> Atom r n -> Atom r n -> Maybe (Atom r n) +peepholeBinOp op x y = case op of + IAdd -> case (x, y) of + (Con (Lit x'), y') | getIntLit x' == 0 -> Just y' + (x', Con (Lit y')) | getIntLit y' == 0 -> Just x' + (Con (Lit x'), Con (Lit y')) -> Just $ Con $ Lit $ applyIntBinOp (+) x' y' + _ -> Nothing + ISub -> case (x, y) of + (x', Con (Lit y')) | getIntLit y' == 0 -> Just x' + (Con (Lit x'), Con (Lit y')) -> Just $ Con $ Lit $ applyIntBinOp (-) x' y' + _ -> Nothing + IMul -> case (x, y) of + (Con (Lit x'), y') | getIntLit x' == 1 -> Just y' + (x', Con (Lit y')) | getIntLit y' == 1 -> Just x' + (Con (Lit x'), Con (Lit y')) -> Just $ Con $ Lit $ applyIntBinOp (*) x' y' + _ -> Nothing + IDiv -> case (x, y) of + (x', Con (Lit y')) | getIntLit y' == 1 -> Just x' + _ -> Nothing + ICmp cop -> case (x, y) of + (Con (Lit x'), Con (Lit y')) -> Just $ Con $ Lit $ applyIntCmpOp (cmp cop) x' y' + _ -> Nothing + FAdd -> case (x, y) of + (Con (Lit x'), y') | getFloatLit x' == 0 -> Just y' + (x', Con (Lit y')) | getFloatLit y' == 0 -> Just x' + (Con (Lit x'), Con (Lit y')) -> Just $ Con $ Lit $ applyFloatBinOp (+) x' y' + _ -> Nothing + FSub -> case (x, y) of + (x', Con (Lit y')) | getFloatLit y' == 0 -> Just x' + (Con (Lit x'), Con (Lit y')) -> Just $ Con $ Lit $ applyFloatBinOp (-) x' y' + _ -> Nothing + FMul -> case (x, y) of + (Con (Lit x'), y') | getFloatLit x' == 1 -> Just y' + (x', Con (Lit y')) | getFloatLit y' == 1 -> Just x' + (Con (Lit x'), Con (Lit y')) -> Just $ Con $ Lit $ applyFloatBinOp (*) x' y' + _ -> Nothing + FDiv -> case (x, y) of + (x', Con (Lit y')) | getFloatLit y' == 1 -> Just x' + _ -> Nothing + FCmp cop -> case (x, y) of + (Con (Lit x'), Con (Lit y')) -> Just $ Con $ Lit $ applyFloatCmpOp (cmp cop) x' y' + _ -> Nothing + BOr -> case (x, y) of + (Con (Lit (Word8Lit x')), Con (Lit (Word8Lit y'))) -> Just $ Con $ Lit $ Word8Lit $ x' .|. y' + _ -> Nothing + BAnd -> case (x, y) of + (Con (Lit (Word8Lit lv)), Con (Lit (Word8Lit rv))) -> Just $ Con $ Lit $ Word8Lit $ lv .&. rv + _ -> Nothing + BXor -> Nothing -- TODO + BShL -> Nothing -- TODO + BShR -> Nothing -- TODO + IRem -> Nothing -- TODO + FPow -> Nothing -- TODO +{-# INLINE peepholeBinOp #-} + +instance PeepholeOpt (MiscOp r) r where + peephole = \case + CastOp (TyCon (BaseType (Scalar sTy))) (Con (Lit l)) -> case foldCast sTy l of + Just l' -> Just $ Con $ Lit l' + Nothing -> Nothing + ToEnum ty (Con (Lit (Word8Lit tag))) -> case ty of + TyCon (SumType cases) -> Just $ Con $ SumCon cases (fromIntegral tag) UnitVal + _ -> error "Ill typed ToEnum" + SumTag (Con (SumCon _ tag _)) -> Just $ Con $ Lit $ Word8Lit $ fromIntegral tag + Select p x y -> case p of + Con (Lit (Word8Lit p')) -> Just if p' /= 0 then x else y + _ -> Nothing + _ -> Nothing + +foldCast :: ScalarBaseType -> LitVal -> Maybe LitVal +foldCast sTy l = case sTy of + -- TODO: Check that the casts relating to floating-point agree with the + -- runtime behavior. The runtime is given by the `ICastOp` case in + -- ImpToLLVM.hs. We should make sure that the Haskell functions here + -- produce bitwise identical results to those instructions, by adjusting + -- either this or that as called for. + -- TODO: Also implement casts that may have unrepresentable results, i.e., + -- casting floating-point numbers to smaller floating-point numbers or to + -- fixed-point. Both of these necessarily have a much smaller dynamic range. + Int32Type -> case l of + Int32Lit _ -> Just l + Int64Lit i -> Just $ Int32Lit $ fromIntegral i + Word8Lit i -> Just $ Int32Lit $ fromIntegral i + Word32Lit i -> Just $ Int32Lit $ fromIntegral i + Word64Lit i -> Just $ Int32Lit $ fromIntegral i + Float32Lit _ -> Nothing + Float64Lit _ -> Nothing + PtrLit _ _ -> Nothing + Int64Type -> case l of + Int32Lit i -> Just $ Int64Lit $ fromIntegral i + Int64Lit _ -> Just l + Word8Lit i -> Just $ Int64Lit $ fromIntegral i + Word32Lit i -> Just $ Int64Lit $ fromIntegral i + Word64Lit i -> Just $ Int64Lit $ fromIntegral i + Float32Lit _ -> Nothing + Float64Lit _ -> Nothing + PtrLit _ _ -> Nothing + Word8Type -> case l of + Int32Lit i -> Just $ Word8Lit $ fromIntegral i + Int64Lit i -> Just $ Word8Lit $ fromIntegral i + Word8Lit _ -> Just l + Word32Lit i -> Just $ Word8Lit $ fromIntegral i + Word64Lit i -> Just $ Word8Lit $ fromIntegral i + Float32Lit _ -> Nothing + Float64Lit _ -> Nothing + PtrLit _ _ -> Nothing + Word32Type -> case l of + Int32Lit i -> Just $ Word32Lit $ fromIntegral i + Int64Lit i -> Just $ Word32Lit $ fromIntegral i + Word8Lit i -> Just $ Word32Lit $ fromIntegral i + Word32Lit _ -> Just l + Word64Lit i -> Just $ Word32Lit $ fromIntegral i + Float32Lit _ -> Nothing + Float64Lit _ -> Nothing + PtrLit _ _ -> Nothing + Word64Type -> case l of + Int32Lit i -> Just $ Word64Lit $ fromIntegral (fromIntegral i :: Word32) + Int64Lit i -> Just $ Word64Lit $ fromIntegral i + Word8Lit i -> Just $ Word64Lit $ fromIntegral i + Word32Lit i -> Just $ Word64Lit $ fromIntegral i + Word64Lit _ -> Just l + Float32Lit _ -> Nothing + Float64Lit _ -> Nothing + PtrLit _ _ -> Nothing + Float32Type -> case l of + Int32Lit i -> Just $ Float32Lit $ fixUlp i $ fromIntegral i + Int64Lit i -> Just $ Float32Lit $ fixUlp i $ fromIntegral i + Word8Lit i -> Just $ Float32Lit $ fromIntegral i + Word32Lit i -> Just $ Float32Lit $ fixUlp i $ fromIntegral i + Word64Lit i -> Just $ Float32Lit $ fixUlp i $ fromIntegral i + Float32Lit _ -> Just l + Float64Lit _ -> Nothing + PtrLit _ _ -> Nothing + Float64Type -> case l of + Int32Lit i -> Just $ Float64Lit $ fromIntegral i + Int64Lit i -> Just $ Float64Lit $ fixUlp i $ fromIntegral i + Word8Lit i -> Just $ Float64Lit $ fromIntegral i + Word32Lit i -> Just $ Float64Lit $ fromIntegral i + Word64Lit i -> Just $ Float64Lit $ fixUlp i $ fromIntegral i + Float32Lit f -> Just $ Float64Lit $ float2Double f + Float64Lit _ -> Just l + PtrLit _ _ -> Nothing + where + -- When casting an integer type to a floating-point type of lower precision + -- (e.g., int32 to float32), GHC between 7.8.3 and 9.2.2 (exclusive) rounds + -- toward zero, instead of rounding to nearest even like everybody else. + -- See https://gitlab.haskell.org/ghc/ghc/-/issues/17231. + -- + -- We patch this by manually checking the two adjacent floats to the + -- candidate answer, and using one of those if the reverse cast is closer + -- to the original input. + -- + -- This rounds to nearest. We round to nearest *even* by considering the + -- candidates in decreasing order of the number of trailing zeros they + -- exhibit when cast back to the original integer type. + fixUlp :: forall a b w. (Num a, Integral a, FiniteBits a, RealFrac b, FloatingBits b w) + => a -> b -> b + fixUlp orig candidate = res where + res = closest $ sortBy moreLowBits [candidate, candidatem1, candidatep1] + candidatem1 = nextDown candidate + candidatep1 = nextUp candidate + closest = minimumBy (\ca cb -> err ca `compare` err cb) + err cand = absdiff orig (round cand) + absdiff a b = if a >= b then a - b else b - a + moreLowBits a b = + compare (0 - countTrailingZeros (round @b @a a)) + (0 - countTrailingZeros (round @b @a b)) + +-- === Helpers for function evaluation over fixed-width types === + +applyIntBinOp :: (forall a. (Num a, Integral a) => a -> a -> a) -> LitVal -> LitVal -> LitVal +applyIntBinOp f x y = case (x, y) of + (Int64Lit x', Int64Lit y') -> Int64Lit $ f x' y' + (Int32Lit x', Int32Lit y') -> Int32Lit $ f x' y' + (Word8Lit x', Word8Lit y') -> Word8Lit $ f x' y' + (Word32Lit x', Word32Lit y') -> Word32Lit $ f x' y' + (Word64Lit x', Word64Lit y') -> Word64Lit $ f x' y' + _ -> error "Expected integer atoms" + +applyIntCmpOp :: (forall a. (Eq a, Ord a) => a -> a -> Bool) -> LitVal -> LitVal -> LitVal +applyIntCmpOp f x y = boolLit case (x, y) of + (Int64Lit x', Int64Lit y') -> f x' y' + (Int32Lit x', Int32Lit y') -> f x' y' + (Word8Lit x', Word8Lit y') -> f x' y' + (Word32Lit x', Word32Lit y') -> f x' y' + (Word64Lit x', Word64Lit y') -> f x' y' + _ -> error "Expected integer atoms" + +applyFloatBinOp :: (forall a. (Num a, Fractional a) => a -> a -> a) -> LitVal -> LitVal -> LitVal +applyFloatBinOp f x y = case (x, y) of + (Float64Lit x', Float64Lit y') -> Float64Lit $ f x' y' + (Float32Lit x', Float32Lit y') -> Float32Lit $ f x' y' + _ -> error "Expected float atoms" + +applyFloatCmpOp :: (forall a. (Eq a, Ord a) => a -> a -> Bool) -> LitVal -> LitVal -> LitVal +applyFloatCmpOp f x y = boolLit case (x, y) of + (Float64Lit x', Float64Lit y') -> f x' y' + (Float32Lit x', Float32Lit y') -> f x' y' + _ -> error "Expected float atoms" + +boolLit :: Bool -> LitVal +boolLit x = Word8Lit $ fromIntegral $ fromEnum x + +cmp :: Ord a => CmpOp -> a -> a -> Bool +cmp = \case + Less -> (<) + Greater -> (>) + Equal -> (==) + LessEqual -> (<=) + GreaterEqual -> (>=) diff --git a/src/lib/QueryType.hs b/src/lib/QueryType.hs index 50a976816..6229a1e8d 100644 --- a/src/lib/QueryType.hs +++ b/src/lib/QueryType.hs @@ -8,12 +8,15 @@ module QueryType (module QueryType, module QueryTypePure, toAtomVar) where import Control.Category ((>>>)) import Control.Monad +import Control.Applicative import Data.List (elemIndex) +import Data.Maybe (fromJust) import Data.Functor ((<&>)) import Types.Primitives import Types.Core import Types.Source +import Types.Top import Types.Imp import IRVariants import Core @@ -21,22 +24,17 @@ import Err import Name hiding (withFreshM) import Subst import Util -import PPrint () +import PPrint import QueryTypePure import CheapReduction -sourceNameType :: (EnvReader m, Fallible1 m) => SourceName -> m n (Type CoreIR n) -sourceNameType v = do - lookupSourceMap v >>= \case - Nothing -> throw UnboundVarErr $ pprint v - Just uvar -> getUVarType uvar -- === Exposed helpers for querying types and effects === caseAltsBinderTys :: (EnvReader m, IRRep r) => Type r n -> m n [Type r n] caseAltsBinderTys ty = case ty of - SumTy types -> return types - NewtypeTyCon t -> case t of + TyCon (SumType types) -> return types -- need this case? + TyCon (NewtypeTyCon t) -> case t of UserADTType _ defName params -> do def <- lookupTyCon defName ~(ADTCons cons) <- instantiateTyConDef def params @@ -48,20 +46,6 @@ caseAltsBinderTys ty = case ty of extendEffect :: IRRep r => Effect r n -> EffectRow r n -> EffectRow r n extendEffect eff (EffectRow effs t) = EffectRow (effs <> eSetSingleton eff) t -blockEffTy :: (EnvReader m, IRRep r) => Block r n -> m n (EffTy r n) -blockEffTy block = liftEnvReaderM $ refreshAbs block \decls result -> do - effs <- declsEffects decls mempty - return $ ignoreHoistFailure $ hoist decls $ EffTy effs $ getType result - where - declsEffects :: IRRep r => Nest (Decl r) n l -> EffectRow r l -> EnvReaderM l (EffectRow r l) - declsEffects Empty !acc = return acc - declsEffects n@(Nest (Let _ (DeclBinding _ expr)) rest) !acc = withExtEvidence n do - expr' <- sinkM expr - declsEffects rest $ acc <> getEffects expr' - -blockTy :: (EnvReader m, IRRep r) => Block r n -> m n (Type r n) -blockTy b = blockEffTy b <&> \(EffTy _ t) -> t - piTypeWithoutDest :: PiType SimpIR n -> PiType SimpIR n piTypeWithoutDest (PiType bsRefB _) = case popNest bsRefB of @@ -69,75 +53,48 @@ piTypeWithoutDest (PiType bsRefB _) = PiType bs $ EffTy Pure ansTy -- XXX: we ignore the effects here _ -> error "expected trailing dest binder" -blockEff :: (EnvReader m, IRRep r) => Block r n -> m n (EffectRow r n) -blockEff b = blockEffTy b <&> \(EffTy eff _) -> eff - -typeOfApp :: (IRRep r, EnvReader m) => Type r n -> [Atom r n] -> m n (Type r n) -typeOfApp (Pi piTy) xs = withSubstReaderT $ - withInstantiated piTy xs \(EffTy _ ty) -> substM ty -typeOfApp _ _ = error "expected a pi type" - -typeOfTabApp :: (IRRep r, EnvReader m) => Type r n -> [Atom r n] -> m n (Type r n) -typeOfTabApp t [] = return t -typeOfTabApp (TabPi tabTy) (i:rest) = do - resultTy <- instantiate tabTy [i] - typeOfTabApp resultTy rest +typeOfTabApp :: (IRRep r, EnvReader m) => Type r n -> Atom r n -> m n (Type r n) +typeOfTabApp (TyCon (TabPi tabTy)) i = instantiate tabTy [i] typeOfTabApp ty _ = error $ "expected a table type. Got: " ++ pprint ty -typeOfApplyMethod :: EnvReader m => CAtom n -> Int -> [CAtom n] -> m n (EffTy CoreIR n) +typeOfApplyMethod :: EnvReader m => CDict n -> Int -> [CAtom n] -> m n (EffTy CoreIR n) typeOfApplyMethod d i args = do - ty <- Pi <$> getMethodType d i + ty <- toType <$> getMethodType d i appEffTy ty args -typeOfDictExpr :: EnvReader m => DictExpr n -> m n (CType n) -typeOfDictExpr e = liftM ignoreExcept $ liftEnvReaderT $ case e of - InstanceDict instanceName args -> do - instanceDef@(InstanceDef className _ _ _ _) <- lookupInstanceDef instanceName - sourceName <- getSourceName <$> lookupClassDef className - PairE (ListE params) _ <- instantiate instanceDef args - return $ DictTy $ DictType sourceName className params - InstantiatedGiven given args -> typeOfApp (getType given) args - SuperclassProj d i -> do - DictTy (DictType _ className params) <- return $ getType d - classDef <- lookupClassDef className - withSubstReaderT $ withInstantiated classDef params \(Abs superclasses _) -> do - substM $ getSuperclassType REmpty superclasses i - IxFin n -> liftM DictTy $ ixDictType $ NewtypeTyCon $ Fin n - DataData ty -> DictTy <$> dataDictType ty - typeOfTopApp :: EnvReader m => TopFunName n -> [SAtom n] -> m n (EffTy SimpIR n) typeOfTopApp f xs = do piTy <- getTypeTopFun f instantiate piTy xs typeOfIndexRef :: (EnvReader m, Fallible1 m, IRRep r) => Type r n -> Atom r n -> m n (Type r n) -typeOfIndexRef (TC (RefType h s)) i = do - TabPi tabPi <- return s +typeOfIndexRef (TyCon (RefType h s)) i = do + TyCon (TabPi tabPi) <- return s eltTy <- instantiate tabPi [i] - return $ TC $ RefType h eltTy + return $ toType $ RefType h eltTy typeOfIndexRef _ _ = error "expected a ref type" typeOfProjRef :: EnvReader m => Type r n -> Projection -> m n (Type r n) -typeOfProjRef (TC (RefType h s)) p = do - TC . RefType h <$> case p of +typeOfProjRef (TyCon (RefType h s)) p = do + toType . RefType h <$> case p of ProjectProduct i -> do - ~(ProdTy tys) <- return s + ~(TyCon (ProdType tys)) <- return s return $ tys !! i UnwrapNewtype -> do case s of - NewtypeTyCon tc -> snd <$> unwrapNewtypeType tc + TyCon (NewtypeTyCon tc) -> snd <$> unwrapNewtypeType tc _ -> error "expected a newtype" typeOfProjRef _ _ = error "expected a reference" appEffTy :: (IRRep r, EnvReader m) => Type r n -> [Atom r n] -> m n (EffTy r n) -appEffTy (Pi piTy) xs = instantiate piTy xs +appEffTy (TyCon (Pi piTy)) xs = instantiate piTy xs appEffTy t _ = error $ "expected a pi type, got: " ++ pprint t partialAppType :: (IRRep r, EnvReader m) => Type r n -> [Atom r n] -> m n (Type r n) -partialAppType (Pi (CorePiType appExpl expls bs effTy)) xs = do +partialAppType (TyCon (Pi (CorePiType appExpl expls bs effTy))) xs = do (_, expls2) <- return $ splitAt (length xs) expls PairB bs1 bs2 <- return $ splitNestAt (length xs) bs - instantiate (Abs bs1 (Pi $ CorePiType appExpl expls2 bs2 effTy)) xs + instantiate (Abs bs1 (toType $ CorePiType appExpl expls2 bs2 effTy)) xs partialAppType _ _ = error "expected a pi type" effTyOfHof :: (EnvReader m, IRRep r) => Hof r n -> m n (EffTy r n) @@ -152,7 +109,7 @@ typeOfHof = \case Linearize f _ -> getLamExprType f >>= \case PiType (UnaryNest (binder:>a)) (EffTy Pure b) -> do let b' = ignoreHoistFailure $ hoist binder b - let fLinTy = Pi $ nonDepPiType [a] Pure b' + let fLinTy = toType $ nonDepPiType [a] Pure b' return $ PairTy b' fLinTy _ -> error "expected a unary pi type" Transpose f _ -> getLamExprType f >>= \case @@ -165,22 +122,22 @@ typeOfHof = \case RunState _ _ f -> do (resultTy, stateTy) <- getTypeRWSAction f return $ PairTy resultTy stateTy - RunIO f -> blockTy f - RunInit f -> blockTy f + RunIO f -> return $ getType f + RunInit f -> return $ getType f CatchException ty _ -> return ty hofEffects :: (EnvReader m, IRRep r) => Hof r n -> m n (EffectRow r n) hofEffects = \case For _ _ f -> functionEffs f - While body -> blockEff body + While body -> return $ getEffects body Linearize _ _ -> return Pure -- Body has to be a pure function Transpose _ _ -> return Pure -- Body has to be a pure function RunReader _ f -> rwsFunEffects Reader f RunWriter d _ f -> maybeInit d <$> rwsFunEffects Writer f RunState d _ f -> maybeInit d <$> rwsFunEffects State f - RunIO f -> deleteEff IOEffect <$> blockEff f - RunInit f -> deleteEff InitEffect <$> blockEff f - CatchException _ f -> deleteEff ExceptionEffect <$> blockEff f + RunIO f -> return $ deleteEff IOEffect $ getEffects f + RunInit f -> return $ deleteEff InitEffect $ getEffects f + CatchException _ f -> return $ deleteEff ExceptionEffect $ getEffects f where maybeInit :: IRRep r => Maybe (Atom r i) -> (EffectRow r o -> EffectRow r o) maybeInit d = case d of Just _ -> (<>OneEffect InitEffect); Nothing -> id @@ -189,9 +146,9 @@ deleteEff eff (EffectRow effs t) = EffectRow (effs `eSetDifference` eSetSingleto getMethodIndex :: EnvReader m => ClassName n -> SourceName -> m n Int getMethodIndex className methodSourceName = do - ClassDef _ methodNames _ _ _ _ _ <- lookupClassDef className + ClassDef _ _ methodNames _ _ _ _ _ <- lookupClassDef className case elemIndex methodSourceName methodNames of - Nothing -> error $ methodSourceName ++ " is not a method of " ++ pprint className + Nothing -> error $ pprint methodSourceName ++ " is not a method of " ++ pprint className Just i -> return i {-# INLINE getMethodIndex #-} @@ -202,41 +159,52 @@ getUVarType = \case UDataConVar v -> getDataConNameType v UPunVar v -> getStructDataConType v UClassVar v -> do - ClassDef _ _ _ roleExpls bs _ _ <- lookupClassDef v - return $ Pi $ CorePiType ExplicitApp (map snd roleExpls) bs $ EffTy Pure TyKind + ClassDef _ _ _ _ roleExpls bs _ _ <- lookupClassDef v + return $ toType $ CorePiType ExplicitApp (map snd roleExpls) bs $ EffTy Pure TyKind UMethodVar v -> getMethodNameType v - UEffectVar _ -> error "not implemented" - UEffectOpVar _ -> error "not implemented" getMethodNameType :: EnvReader m => MethodName n -> m n (CType n) getMethodNameType v = liftEnvReaderM $ lookupEnv v >>= \case MethodBinding className i -> do - ClassDef _ _ paramNames _ paramBs scBinders methodTys <- lookupClassDef className + ClassDef _ _ _ paramNames _ paramBs scBinders methodTys <- lookupClassDef className refreshAbs (Abs paramBs $ Abs scBinders (methodTys !! i)) \paramBs' absPiTy -> do - let params = Var <$> bindersVars paramBs' - dictTy <- DictTy <$> dictType (sink className) params + let params = toAtom <$> bindersVars paramBs' + dictTy <- toType <$> dictType (sink className) params withFreshBinder noHint dictTy \dictB -> do - scDicts <- getSuperclassDicts (Var $ binderVar dictB) + scDicts <- getSuperclassDicts (toDict $ binderVar dictB) CorePiType appExpl methodExpls methodBs effTy <- instantiate (sink absPiTy) scDicts let paramExpls = paramNames <&> \name -> Inferred name Unify let expls = paramExpls <> [Inferred Nothing (Synth $ Partial $ succ i)] <> methodExpls - return $ Pi $ CorePiType appExpl expls (paramBs' >>> UnaryNest dictB >>> methodBs) effTy - -getMethodType :: EnvReader m => Dict n -> Int -> m n (CorePiType n) -getMethodType dict i = liftEnvReaderM $ withSubstReaderT do - ~(DictTy (DictType _ className params)) <- return $ getType dict - superclassDicts <- getSuperclassDicts dict - classDef <- lookupClassDef className - withInstantiated classDef params \ab -> do - withInstantiated ab superclassDicts \(ListE methodTys) -> - substM $ methodTys !! i + return $ toType $ CorePiType appExpl expls (paramBs' >>> UnaryNest dictB >>> methodBs) effTy + +getMethodType :: EnvReader m => CDict n -> Int -> m n (CorePiType n) +getMethodType dict i = do + ~(TyCon (DictTy dictTy)) <- return $ getType dict + case dictTy of + DictType _ className params -> liftEnvReaderM $ withSubstReaderT do + superclassDicts <- getSuperclassDicts dict + classDef <- lookupClassDef className + withInstantiated classDef params \ab -> do + withInstantiated ab superclassDicts \(ListE methodTys) -> + substM $ methodTys !! i + IxDictType ixTy -> liftEnvReaderM case i of + 0 -> mkCorePiType [] NatTy -- size' : () -> Nat + 1 -> mkCorePiType [ixTy] NatTy -- ordinal : (n) -> Nat + 2 -> mkCorePiType [NatTy] ixTy -- unsafe_from_ordinal : (Nat) -> n + _ -> error "Ix only has three methods" + DataDictType _ -> error "Data class has no methods" + +mkCorePiType :: EnvReader m => [CType n] -> CType n -> m n (CorePiType n) +mkCorePiType argTys resultTy = liftEnvReaderM $ withFreshBinders argTys \bs _ -> do + expls <- return $ nestToList (const Explicit) bs + return $ CorePiType ExplicitApp expls bs (EffTy Pure (sink resultTy)) getTyConNameType :: EnvReader m => TyConName n -> m n (Type CoreIR n) getTyConNameType v = do TyConDef _ expls bs _ <- lookupTyCon v case bs of Empty -> return TyKind - _ -> return $ Pi $ CorePiType ExplicitApp (snd <$> expls) bs $ EffTy Pure TyKind + _ -> return $ toType $ CorePiType ExplicitApp (snd <$> expls) bs $ EffTy Pure TyKind getDataConNameType :: EnvReader m => DataConName n -> m n (Type CoreIR n) getDataConNameType dataCon = liftEnvReaderM $ withSubstReaderT do @@ -248,9 +216,9 @@ getDataConNameType dataCon = liftEnvReaderM $ withSubstReaderT do refreshAbs ab \dataBs UnitE -> do let appExpl = case dataBs of Empty -> ImplicitApp _ -> ExplicitApp - let resultTy = NewtypeTyCon $ UserADTType (getSourceName tyConDef) (sink tyCon) (sink params) + let resultTy = toType $ UserADTType (getSourceName tyConDef) (sink tyCon) (sink params) let dataExpls = nestToList (const $ Explicit) dataBs - return $ Pi $ CorePiType appExpl (expls <> dataExpls) (paramBs' >>> dataBs) (EffTy Pure resultTy) + return $ toType $ CorePiType appExpl (expls <> dataExpls) (paramBs' >>> dataBs) (EffTy Pure resultTy) getStructDataConType :: EnvReader m => TyConName n -> m n (CType n) getStructDataConType tyCon = liftEnvReaderM $ withSubstReaderT do @@ -258,10 +226,10 @@ getStructDataConType tyCon = liftEnvReaderM $ withSubstReaderT do buildDataConType tyConDef \expls paramBs' paramVs params -> do withInstantiatedNames tyConDef paramVs \(StructFields fields) -> do fieldTys <- forM fields \(_, t) -> renameM t - let resultTy = NewtypeTyCon $ UserADTType (getSourceName tyConDef) (sink tyCon) params + let resultTy = toType $ UserADTType (getSourceName tyConDef) (sink tyCon) params Abs dataBs resultTy' <- return $ typesAsBinderNest fieldTys resultTy let dataExpls = nestToList (const Explicit) dataBs - return $ Pi $ CorePiType ExplicitApp (expls <> dataExpls) (paramBs' >>> dataBs) (EffTy Pure resultTy') + return $ toType $ CorePiType ExplicitApp (expls <> dataExpls) (paramBs' >>> dataBs) (EffTy Pure resultTy') buildDataConType :: (EnvReader m, EnvExtender m) @@ -276,44 +244,29 @@ buildDataConType (TyConDef _ roleExpls bs _) cont = do refreshAbs (Abs bs UnitE) \bs' UnitE -> do let vs = nestToNames bs' vs' <- mapM toAtomVar vs - cont expls' bs' vs $ TyConParams expls (Var <$> vs') + cont expls' bs' vs $ TyConParams expls (toAtom <$> vs') makeTyConParams :: EnvReader m => TyConName n -> [CAtom n] -> m n (TyConParams n) makeTyConParams tc params = do TyConDef _ expls _ _ <- lookupTyCon tc return $ TyConParams (map snd expls) params -getDataClassName :: (Fallible1 m, EnvReader m) => m n (ClassName n) -getDataClassName = lookupSourceMap "Data" >>= \case - Nothing -> throw CompilerErr $ "Data interface needed but not defined!" - Just (UClassVar v) -> return v - Just _ -> error "not a class var" - -dataDictType :: (Fallible1 m, EnvReader m) => CType n -> m n (DictType n) -dataDictType ty = do - dataClassName <- getDataClassName - dictType dataClassName [Type ty] - -getIxClassName :: (Fallible1 m, EnvReader m) => m n (ClassName n) -getIxClassName = lookupSourceMap "Ix" >>= \case - Nothing -> throw CompilerErr $ "Ix interface needed but not defined!" - Just (UClassVar v) -> return v - Just _ -> error "not a class var" - dictType :: EnvReader m => ClassName n -> [CAtom n] -> m n (DictType n) dictType className params = do - ClassDef sourceName _ _ _ _ _ _ <- lookupClassDef className - return $ DictType sourceName className params - -ixDictType :: (Fallible1 m, EnvReader m) => CType n -> m n (DictType n) -ixDictType ty = do - ixClassName <- getIxClassName - dictType ixClassName [Type ty] + ClassDef sourceName builtinName _ _ _ _ _ _ <- lookupClassDef className + return case builtinName of + Just Ix -> IxDictType singleTyParam + Just Data -> DataDictType singleTyParam + Nothing -> DictType sourceName className params + where singleTyParam = case params of + [p] -> fromJust $ toMaybeType p + _ -> error "not a single type param" makePreludeMaybeTy :: EnvReader m => CType n -> m n (CType n) makePreludeMaybeTy ty = do ~(Just (UTyConVar tyConName)) <- lookupSourceMap "Maybe" - return $ TypeCon "Maybe" tyConName $ TyConParams [Explicit] [Type ty] + let params = TyConParams [Explicit] [toAtom ty] + return $ toType $ UserADTType "Maybe" tyConName params -- === computing effects === @@ -325,16 +278,14 @@ rwsFunEffects :: (IRRep r, EnvReader m) => RWS -> LamExpr r n -> m n (EffectRow rwsFunEffects rws f = getLamExprType f >>= \case PiType (BinaryNest h ref) et -> do let effs' = ignoreHoistFailure $ hoist ref (etEff et) - let hVal = Var $ AtomVar (binderName h) (TC HeapType) + let hVal = toAtom $ AtomVar (binderName h) (TyCon HeapType) let effs'' = deleteEff (RWSEffect rws hVal) effs' return $ ignoreHoistFailure $ hoist h effs'' _ -> error "Expected a binary function type" getLamExprType :: (IRRep r, EnvReader m) => LamExpr r n -> m n (PiType r n) -getLamExprType (LamExpr bs body) = liftEnvReaderM $ - refreshAbs (Abs bs body) \bs' body' -> do - effTy <- blockEffTy body' - return $ PiType bs' effTy +getLamExprType (LamExpr bs body) = + return $ PiType bs $ EffTy (getEffects body) (getType body) getTypeRWSAction :: (IRRep r, EnvReader m) => LamExpr r n -> m n (Type r n, Type r n) getTypeRWSAction f = getLamExprType f >>= \case @@ -347,19 +298,22 @@ getTypeRWSAction f = getLamExprType f >>= \case _ -> error "expected a ref" _ -> error "expected a pi type" -getSuperclassDicts :: EnvReader m => CAtom n -> m n ([CAtom n]) +getSuperclassDicts :: EnvReader m => CDict n -> m n ([CAtom n]) getSuperclassDicts dict = do case getType dict of - DictTy dTy -> do + TyCon (DictTy dTy) -> do ts <- getSuperclassTys dTy - forM (enumerate ts) \(i, t) -> return $ DictCon t $ SuperclassProj dict i + forM (enumerate ts) \(i, _) -> reduceSuperclassProj i dict _ -> error "expected a dict type" getSuperclassTys :: EnvReader m => DictType n -> m n [CType n] -getSuperclassTys (DictType _ className params) = do - ClassDef _ _ _ _ bs superclasses _ <- lookupClassDef className - forM [0 .. nestLength superclasses - 1] \i -> do - instantiate (Abs bs $ getSuperclassType REmpty superclasses i) params +getSuperclassTys = \case + DictType _ className params -> do + ClassDef _ _ _ _ _ bs superclasses _ <- lookupClassDef className + forM [0 .. nestLength superclasses - 1] \i -> do + instantiate (Abs bs $ getSuperclassType REmpty superclasses i) params + DataDictType _ -> return [] + IxDictType ty -> return [toType $ DataDictType ty] getTypeTopFun :: EnvReader m => TopFunName n -> m n (PiType SimpIR n) getTypeTopFun f = lookupTopFun f >>= \case @@ -378,10 +332,10 @@ liftIFunType (IFunType _ argTys resultTys) = liftEnvReaderM $ go argTys where [] -> return $ PiType Empty (EffTy (OneEffect IOEffect) resultTy) where resultTy = case resultTys of [] -> UnitTy - [t] -> BaseTy t - [t1, t2] -> PairTy (BaseTy t1) (BaseTy t2) + [t] -> toType $ BaseType t + [t1, t2] -> TyCon (ProdType [toType $ BaseType t1, toType $ BaseType t2]) _ -> error $ "Not a valid FFI return type: " ++ pprint resultTys - t:ts -> withFreshBinder noHint (BaseTy t) \b -> do + t:ts -> withFreshBinder noHint (toType $ BaseType t) \b -> do PiType bs effTy <- go ts return $ PiType (Nest b bs) effTy @@ -389,34 +343,29 @@ liftIFunType (IFunType _ argTys resultTys) = liftEnvReaderM $ go argTys where isData :: EnvReader m => Type CoreIR n -> m n Bool isData ty = do - result <- liftEnvReaderT $ withSubstReaderT $ checkDataLike ty - case runFallibleM result of - Success () -> return True - Failure _ -> return False - -checkDataLike :: Type CoreIR i -> SubstReaderT Name FallibleEnvReaderM i o () -checkDataLike ty = case ty of - TyVar _ -> notData - TabPi (TabPiType _ b eltTy) -> do - renameBinders b \_ -> - checkDataLike eltTy - DepPairTy (DepPairType _ b@(_:>l) r) -> do - recur l - renameBinders b \_ -> checkDataLike r - NewtypeTyCon nt -> do - (_, ty') <- unwrapNewtypeType =<< renameM nt - dropSubst $ recur ty' - TC con -> case con of - BaseType _ -> return () - ProdType as -> mapM_ recur as - SumType cs -> mapM_ recur cs - RefType _ _ -> return () - HeapType -> return () - _ -> notData - _ -> notData + result <- liftEnvReaderT $ withSubstReaderT $ go ty + case result of + Just () -> return True + Nothing -> return False where - recur = checkDataLike - notData = throw TypeErr $ pprint ty + go :: Type CoreIR i -> SubstReaderT Name (EnvReaderT Maybe) i o () + go = \case + StuckTy _ _ -> notData + TyCon con -> case con of + TabPi (TabPiType _ b eltTy) -> renameBinders b \_ -> go eltTy + DepPairTy (DepPairType _ b@(_:>l) r) -> go l >> renameBinders b \_ -> go r + NewtypeTyCon nt -> do + (_, ty') <- unwrapNewtypeType =<< renameM nt + dropSubst $ go ty' + BaseType _ -> return () + ProdType as -> mapM_ go as + SumType cs -> mapM_ go cs + RefType _ _ -> return () + HeapType -> return () + TypeKind -> notData + DictTy _ -> notData + Pi _ -> notData + where notData = empty checkExtends :: (Fallible m, IRRep r) => EffectRow r n -> EffectRow r n -> m () checkExtends allowed (EffectRow effs effTail) = do @@ -425,6 +374,6 @@ checkExtends allowed (EffectRow effs effTail) = do EffectRowTail _ -> assertEq allowedEffTail effTail "" NoTail -> return () forM_ (eSetToList effs) \eff -> unless (eff `eSetMember` allowedEffs) $ - throw CompilerErr $ "Unexpected effect: " ++ pprint eff ++ - "\nAllowed: " ++ pprint allowed - + throwInternal $ "Unexpected effect: " ++ pprint eff ++ + "\nAllowed: " ++ pprint allowed +{-# INLINE checkExtends #-} diff --git a/src/lib/QueryTypePure.hs b/src/lib/QueryTypePure.hs index 9be267241..153ed5449 100644 --- a/src/lib/QueryTypePure.hs +++ b/src/lib/QueryTypePure.hs @@ -8,6 +8,7 @@ module QueryTypePure where import Types.Primitives import Types.Core +import Types.Top import IRVariants import Name @@ -17,6 +18,9 @@ class HasType (r::IR) (e::E) | e -> r where class HasEffects (e::E) (r::IR) | e -> r where getEffects :: e n -> EffectRow r n +getTyCon :: HasType SimpIR e => e n -> TyCon SimpIR n +getTyCon e = con where TyCon con = getType e + isPure :: (IRRep r, HasEffects e r) => e n -> Bool isPure e = case getEffects e of Pure -> True @@ -28,11 +32,12 @@ instance IRRep r => HasType r (AtomBinding r) where getType = \case LetBound (DeclBinding _ e) -> getType e MiscBound ty -> ty - SolverBound (InfVarBound ty _) -> ty + SolverBound (InfVarBound ty) -> ty SolverBound (SkolemBound ty) -> ty + SolverBound (DictBound ty) -> ty NoinlineFun ty _ -> ty - TopDataBound (RepVal ty _) -> ty - FFIFunBound piTy _ -> Pi piTy + TopDataBound e -> getType e + FFIFunBound piTy _ -> TyCon $ Pi piTy litType :: LitVal -> BaseType litType v = case v of @@ -68,54 +73,48 @@ instance IRRep r => HasType r (AtomVar r) where {-# INLINE getType #-} instance IRRep r => HasType r (Atom r) where - getType atom = case atom of - Var name -> getType name - Lam (CoreLamExpr piTy _) -> Pi piTy - DepPair _ _ ty -> DepPairTy ty - Con con -> getType con - Eff _ -> EffKind - PtrVar t _ -> PtrTy t - DictCon ty _ -> ty - NewtypeCon con _ -> getNewtypeType con - RepValAtom (RepVal ty _) -> ty - ProjectElt t _ _ -> t - SimpInCore x -> getType x - DictHole _ ty _ -> ty - TypeAsAtom ty -> getType ty + getType = \case + Stuck t _ -> t + Con e -> getType e -instance IRRep r => HasType r (Type r) where +instance HasType CoreIR (Dict CoreIR) where getType = \case - NewtypeTyCon con -> getType con - Pi _ -> TyKind - TabPi _ -> TyKind - DepPairTy _ -> TyKind - TC _ -> TyKind - DictTy _ -> TyKind - TyVar v -> getType v - ProjectEltTy t _ _ -> t - -instance HasType CoreIR SimpInCore where + StuckDict t _ -> t + DictCon e -> getType e + +instance HasType CoreIR (DictCon CoreIR) where getType = \case - LiftSimp t _ -> t - LiftSimpFun piTy _ -> Pi $ piTy - TabLam t _ -> TabPi $ t - ACase _ _ t -> t + InstanceDict t _ _ -> t + DataData t -> toType $ DataDictType t + IxFin n -> toType $ IxDictType (FinTy n) + IxRawFin _ -> toType $ IxDictType IdxRepTy + +instance HasType CoreIR CType where + getType = \case + TyCon _ -> TyKind + StuckTy t _ -> t instance HasType CoreIR NewtypeTyCon where getType _ = TyKind getNewtypeType :: NewtypeCon n -> CType n getNewtypeType con = case con of - NatCon -> NewtypeTyCon Nat - FinCon n -> NewtypeTyCon $ Fin n - UserADTData sn d params -> NewtypeTyCon $ UserADTType sn d params + NatCon -> TyCon $ NewtypeTyCon Nat + FinCon n -> TyCon $ NewtypeTyCon $ Fin n + UserADTData sn d xs -> TyCon $ NewtypeTyCon $ UserADTType sn d xs instance IRRep r => HasType r (Con r) where getType = \case - Lit l -> BaseTy $ litType l - ProdCon xs -> ProdTy $ map getType xs - SumCon tys _ _ -> SumTy tys - HeapVal -> TC HeapType + Lit l -> toType $ BaseType $ litType l + ProdCon xs -> toType $ ProdType $ map getType xs + SumCon tys _ _ -> toType $ SumType tys + HeapVal -> toType HeapType + Lam (CoreLamExpr piTy _) -> toType $ Pi piTy + DepPair _ _ ty -> toType $ DepPairTy ty + Eff _ -> EffKind + DictConAtom d -> getType d + NewtypeCon con _ -> getNewtypeType con + TyConAtom _ -> TyKind getSuperclassType :: RNest CBinder n l -> Nest CBinder l l' -> Int -> CType n getSuperclassType _ Empty = error "bad index" @@ -129,10 +128,16 @@ instance IRRep r => HasType r (Expr r) where TopApp (EffTy _ ty) _ _ -> ty TabApp t _ _ -> t Atom x -> getType x + Block (EffTy _ ty) _ -> ty TabCon _ ty _ -> ty PrimOp op -> getType op Case _ _ (EffTy _ resultTy) -> resultTy ApplyMethod (EffTy _ t) _ _ _ -> t + Project t _ _ -> t + Unwrap t _ -> t + +instance HasType SimpIR RepVal where + getType (RepVal ty _) = ty instance IRRep r => HasType r (DAMOp r) where getType = \case @@ -146,15 +151,15 @@ instance IRRep r => HasType r (DAMOp r) where instance IRRep r => HasType r (PrimOp r) where getType primOp = case primOp of - BinOp op x _ -> TC $ BaseType $ typeBinOp op $ getTypeBaseType x - UnOp op x -> TC $ BaseType $ typeUnOp op $ getTypeBaseType x + BinOp op x _ -> TyCon $ BaseType $ typeBinOp op $ getTypeBaseType x + UnOp op x -> TyCon $ BaseType $ typeUnOp op $ getTypeBaseType x Hof (TypedHof (EffTy _ ty) _) -> ty MemOp op -> getType op MiscOp op -> getType op VectorOp op -> getType op DAMOp op -> getType op RefOp ref m -> case getType ref of - TC (RefType _ s) -> case m of + TyCon (RefType _ s) -> case m of MGet -> s MPut _ -> UnitTy MAsk -> s @@ -165,7 +170,7 @@ instance IRRep r => HasType r (PrimOp r) where getTypeBaseType :: (IRRep r, HasType r e) => e n -> BaseType getTypeBaseType e = case getType e of - TC (BaseType b) -> b + TyCon (BaseType b) -> b ty -> error $ "Expected a base type. Got: " ++ show ty instance IRRep r => HasType r (MemOp r) where @@ -175,7 +180,7 @@ instance IRRep r => HasType r (MemOp r) where PtrOffset arr _ -> getType arr PtrLoad ptr -> do let PtrTy (_, t) = getType ptr - BaseTy t + toType $ BaseType t PtrStore _ _ -> UnitTy instance IRRep r => HasType r (VectorOp r) where @@ -184,7 +189,7 @@ instance IRRep r => HasType r (VectorOp r) where VectorIota vty -> vty VectorIdx _ _ vty -> vty VectorSubref ref _ vty -> case getType ref of - TC (RefType h _) -> TC $ RefType h vty + TyCon (RefType h _) -> TyCon $ RefType h vty ty -> error $ "Not a reference type: " ++ show ty instance IRRep r => HasType r (MiscOp r) where @@ -198,20 +203,20 @@ instance IRRep r => HasType r (MiscOp r) where GarbageVal t -> t SumTag _ -> TagRepTy ToEnum t _ -> t - OutputStream -> BaseTy $ hostPtrTy $ Scalar Word8Type + OutputStream -> toType $ BaseType $ hostPtrTy $ Scalar Word8Type where hostPtrTy ty = PtrType (CPU, ty) ShowAny _ -> rawStrType -- TODO: constrain `ShowAny` to have `HasCore r` - ShowScalar _ -> PairTy IdxRepTy $ rawFinTabType (IdxRepVal showStringBufferSize) CharRepTy + ShowScalar _ -> toType $ ProdType [IdxRepTy, rawFinTabType (IdxRepVal showStringBufferSize) CharRepTy] rawStrType :: IRRep r => Type r n rawStrType = case newName "n" of Abs b v -> do - let tabTy = rawFinTabType (Var $ AtomVar v IdxRepTy) CharRepTy - DepPairTy $ DepPairType ExplicitDepPair (b:>IdxRepTy) tabTy + let tabTy = rawFinTabType (toAtom $ AtomVar v IdxRepTy) CharRepTy + TyCon $ DepPairTy $ DepPairType ExplicitDepPair (b:>IdxRepTy) tabTy -- `n` argument is IdxRepVal, not Nat rawFinTabType :: IRRep r => Atom r n -> Type r n -> Type r n -rawFinTabType n eltTy = IxType IdxRepTy (IxDictRawFin n) ==> eltTy +rawFinTabType n eltTy = IxType IdxRepTy (DictCon (IxRawFin n)) ==> eltTy tabIxType :: TabPiType r n -> IxType r n tabIxType (TabPiType d (_:>t) _) = IxType t d @@ -239,27 +244,20 @@ coreLamToTopLam :: CoreLamExpr n -> TopLam CoreIR n coreLamToTopLam (CoreLamExpr ty f) = TopLam False (corePiTypeToPiType ty) f (==>) :: IRRep r => IxType r n -> Type r n -> Type r n -a ==> b = TabPi $ nonDepTabPiType a b +a ==> b = TyCon $ TabPi $ nonDepTabPiType a b -litFinIxTy :: Int -> IxType r n +litFinIxTy :: Int -> IxType SimpIR n litFinIxTy n = finIxTy $ IdxRepVal $ fromIntegral n -finIxTy :: Atom r n -> IxType r n -finIxTy n = IxType IdxRepTy (IxDictRawFin n) - -ixTyFromDict :: IRRep r => IxDict r n -> IxType r n -ixTyFromDict ixDict = flip IxType ixDict $ case ixDict of - IxDictAtom dict -> case getType dict of - DictTy (DictType "Ix" _ [Type iTy]) -> iTy - _ -> error $ "Not an Ix dict: " ++ show dict - IxDictRawFin _ -> IdxRepTy - IxDictSpecialized n _ _ -> n +finIxTy :: Atom SimpIR n -> IxType SimpIR n +finIxTy n = IxType IdxRepTy (DictCon (IxRawFin n)) -- === querying effects implementation === instance IRRep r => HasEffects (Expr r) r where getEffects = \case Atom _ -> Pure + Block (EffTy eff _) _ -> eff App (EffTy eff _) _ _ -> eff TopApp (EffTy eff _) _ _ -> eff TabApp _ _ _ -> Pure @@ -267,6 +265,8 @@ instance IRRep r => HasEffects (Expr r) r where TabCon _ _ _ -> Pure ApplyMethod (EffTy eff _) _ _ _ -> eff PrimOp primOp -> getEffects primOp + Project _ _ _ -> Pure + Unwrap _ _ -> Pure instance IRRep r => HasEffects (DeclBinding r) r where getEffects (DeclBinding _ expr) = getEffects expr @@ -297,7 +297,7 @@ instance IRRep r => HasEffects (PrimOp r) r where ShowAny _ -> Pure ShowScalar _ -> Pure RefOp ref m -> case getType ref of - TC (RefType h _) -> case m of + TyCon (RefType h _) -> case m of MGet -> OneEffect (RWSEffect State h) MPut _ -> OneEffect (RWSEffect State h) MAsk -> OneEffect (RWSEffect Reader h) diff --git a/src/lib/RenderHtml.hs b/src/lib/RenderHtml.hs index f87d2accb..da8e8a28c 100644 --- a/src/lib/RenderHtml.hs +++ b/src/lib/RenderHtml.hs @@ -7,30 +7,155 @@ {-# OPTIONS_GHC -Wno-orphans #-} {-# OPTIONS_GHC -Wno-incomplete-patterns #-} -module RenderHtml (pprintHtml, progHtml, ToMarkup, treeToHtml) where +module RenderHtml ( + progHtml, ToMarkup, renderSourceBlock, renderOutputs, + RenderedSourceBlock, RenderedOutputs) where -import Text.Blaze.Html5 as H hiding (map) +import Text.Blaze.Internal (MarkupM) +import Text.Blaze.Html5 as H hiding (map, b) import Text.Blaze.Html5.Attributes as At import Text.Blaze.Html.Renderer.String -import Data.List qualified as L +import Data.Aeson (ToJSON) +import qualified Data.Map.Strict as M +import Control.Monad.State.Strict +import Control.Monad.Writer.Strict +import Data.Foldable (fold) +import Data.Functor ((<&>)) +import Data.Maybe (fromJust) +import Data.String (fromString) import Data.Text qualified as T import Data.Text.IO qualified as T import CMark (commonmarkToHtml) import System.IO.Unsafe - -import Control.Monad -import Text.Megaparsec hiding (chunk) -import Text.Megaparsec.Char as C +import GHC.Generics import Err -import Lexing (Parser, symChar, keyWordStrs, symbol, parseit, withSource) +import IncState import Paths_dex (getDataFileName) -import PPrint () -import SourceInfo -import TraverseSourceInfo -import Types.Misc +import PPrint import Types.Source -import Util +import Util (unsnoc) + +-- === rendering results === + +-- RenderedOutputs, RenderedSourceBlock aren't 100% HTML themselves but the idea +-- is that they should be trivially convertable to JSON and sent over to the +-- client which can do the final rendering without much code or runtime work. + +type BlockId = Int +data RenderedSourceBlock = RenderedSourceBlock + { rsbLine :: Int + , rsbBlockId :: BlockId + , rsbLexemeList :: [SrcId] + , rsbHtml :: String } + deriving (Generic) + +data RenderedOutputs = RenderedOutputs + { rrHtml :: String + , rrLexemeSpans :: SpanMap + , rrHighlightMap :: HighlightMap + , rrHoverInfoMap :: HoverInfoMap + , rrErrorSrcIds :: [SrcId] } + deriving (Generic) + +renderOutputs :: Outputs -> RenderedOutputs +renderOutputs (Outputs outputs) = fold $ map renderOutput outputs + +renderOutput :: Output -> RenderedOutputs +renderOutput r = RenderedOutputs + { rrHtml = pprintHtml r + , rrLexemeSpans = computeSpanMap r + , rrHighlightMap = computeHighlights r + , rrHoverInfoMap = computeHoverInfo r + , rrErrorSrcIds = computeErrSrcIds r} + +renderSourceBlock :: BlockId -> SourceBlock -> RenderedSourceBlock +renderSourceBlock n b = RenderedSourceBlock + { rsbLine = sbLine b + , rsbBlockId = n + , rsbLexemeList = unsnoc $ lexemeList $ sbLexemeInfo b + , rsbHtml = renderHtml case sbContents b of + Misc (ProseBlock s) -> cdiv "prose-block" $ mdToHtml s + _ -> renderSpans n (sbLexemeInfo b) (sbText b) + } + +instance ToMarkup Outputs where + toMarkup (Outputs outs) = foldMap toMarkup outs + +instance ToMarkup Output where + toMarkup out = case out of + HtmlOut s -> preEscapedString s + SourceInfo _ -> mempty + Error _ -> cdiv "err-block" $ toHtml $ pprint out + _ -> cdiv "result-block" $ toHtml $ pprint out + +instance ToJSON RenderedOutputs +instance ToJSON RenderedSourceBlock + +instance Semigroup RenderedOutputs where + RenderedOutputs x1 y1 z1 w1 v1 <> RenderedOutputs x2 y2 z2 w2 v2 = + RenderedOutputs (x1<>x2) (y1<>y2) (z1<>z2) (w1<>w2) (v1<>v2) + +instance Monoid RenderedOutputs where + mempty = RenderedOutputs mempty mempty mempty mempty mempty + +-- === textual information on hover === + +type HoverInfo = String +newtype HoverInfoMap = HoverInfoMap (M.Map LexemeId HoverInfo) deriving (ToJSON, Semigroup, Monoid) + +computeHoverInfo :: Output -> HoverInfoMap +computeHoverInfo (SourceInfo (SITypeInfo m)) = HoverInfoMap $ fromTypeInfo m +computeHoverInfo _ = mempty + +-- === highlighting on hover === + +newtype SpanMap = SpanMap (M.Map SrcId LexemeSpan) deriving (ToJSON, Semigroup, Monoid) +newtype HighlightMap = HighlightMap (M.Map SrcId Highlights) deriving (ToJSON, Semigroup, Monoid) +type Highlights = [(HighlightType, SrcId)] +data HighlightType = HighlightGroup | HighlightLeaf deriving Generic + +instance ToJSON HighlightType + +computeErrSrcIds :: Output -> [SrcId] +computeErrSrcIds (Error err) = case err of + SearchFailure _ -> [] + InternalErr _ -> [] + ParseErr _ -> [] + SyntaxErr sid _ -> [sid] + NameErr sid _ -> [sid] + TypeErr sid _ -> [sid] + RuntimeErr -> [] + MiscErr _ -> [] +computeErrSrcIds _ = [] + +computeSpanMap :: Output -> SpanMap +computeSpanMap (SourceInfo (SIGroupTree (OverwriteWith tree))) = + execWriter $ go tree where + go :: GroupTree -> Writer SpanMap () + go t = do + tell $ SpanMap $ M.singleton (gtSrcId t) (gtSpan t) + mapM_ go $ gtChildren t +computeSpanMap _ = mempty + +computeHighlights :: Output -> HighlightMap +computeHighlights (SourceInfo (SIGroupTree (OverwriteWith tree))) = + execWriter $ go tree where + go :: GroupTree -> Writer HighlightMap () + go t = do + let children = gtChildren t + let highlights = children <&> \child -> + (getHighlightType (gtIsAtomicLexeme child), gtSrcId child) + forM_ children \child-> do + tell $ HighlightMap $ M.singleton (gtSrcId child) highlights + go child + + getHighlightType :: Bool -> HighlightType + getHighlightType True = HighlightLeaf + getHighlightType False = HighlightGroup +computeHighlights _ = mempty + +-- ----------------- cssSource :: T.Text cssSource = unsafePerformIO $ @@ -47,7 +172,7 @@ pprintHtml x = renderHtml $ toMarkup x progHtml :: (ToMarkup a, ToMarkup b) => [(a, b)] -> String progHtml blocks = renderHtml $ wrapBody $ map toHtmlBlock blocks - where toHtmlBlock (block,result) = toMarkup block <> toMarkup result + where toHtmlBlock (block,outputs) = toMarkup block <> toMarkup outputs wrapBody :: [Html] -> Html wrapBody blocks = docTypeHtml $ do @@ -65,118 +190,57 @@ wrapBody blocks = docTypeHtml $ do inner = foldMap (cdiv "cell") blocks jsSource = textValue $ javascriptSource <> "render(RENDER_MODE.STATIC);" -instance ToMarkup Result where - toMarkup (Result outs err) = foldMap toMarkup outs <> err' - where err' = case err of - Failure e -> cdiv "err-block" $ toHtml $ pprint e - Success () -> mempty - -instance ToMarkup Output where - toMarkup out = case out of - HtmlOut s -> preEscapedString s - _ -> cdiv "result-block" $ toHtml $ pprint out - -instance ToMarkup SourceBlock where - toMarkup block = case sbContents block of - (Misc (ProseBlock s)) -> cdiv "prose-block" $ mdToHtml s - TopDecl decl -> renderSpans decl block - Command _ g -> renderSpans g block - _ -> cdiv "code-block" $ highlightSyntax (sbText block) - mdToHtml :: T.Text -> Html mdToHtml s = preEscapedText $ commonmarkToHtml [] s cdiv :: String -> Html -> Html cdiv c inner = H.div inner ! class_ (stringValue c) --- === syntax highlighting === - -spanDelimitedCode :: SourceBlock -> [SrcPosCtx] -> Html -spanDelimitedCode block ctxs = - let (Just tree) = srcCtxsToTree block ctxs in - spanDelimitedCode' block tree - -spanDelimitedCode' :: SourceBlock -> SpanTree -> Html -spanDelimitedCode' block tree = treeToHtml (sbText block) tree - -treeToHtml :: T.Text -> SpanTree -> Html -treeToHtml source' tree = - let tree' = fillTreeAndAddTrivialLeaves (T.unpack source') tree in - treeToHtml' source' tree' - -treeToHtml' :: T.Text -> SpanTree -> Html -treeToHtml' source' tree = case tree of - Span (_, _, _) children -> - let body' = foldMap (treeToHtml' source') children in - H.span body' ! spanClass - LeafSpan (l, r, _) -> - let spanText = sliceText l r source' in - H.span (highlightSyntax spanText) ! spanLeaf - Trivia (l, r) -> - let spanText = sliceText l r source' in - highlightSyntax spanText - where - spanClass :: Attribute - spanClass = At.class_ "code-span" - - spanLeaf :: Attribute - spanLeaf = At.class_ "code-span-leaf" - -srcCtxsToSpanInfos :: SourceBlock -> [SrcPosCtx] -> [SpanPayload] -srcCtxsToSpanInfos block ctxs = - let blockOffset = sbOffset block in - let ctxs' = L.sort ctxs in - (0, maxBound, 0) : mapMaybe (convert' blockOffset) ctxs' - where convert' :: Int -> SrcPosCtx -> Maybe SpanPayload - convert' offset (SrcPosCtx (Just (l, r)) (Just spanId)) = Just (l - offset, r - offset, spanId + 1) - convert' _ _ = Nothing - -srcCtxsToTree :: SourceBlock -> [SrcPosCtx] -> Maybe SpanTree -srcCtxsToTree block ctxs = makeEmptySpanTree (srcCtxsToSpanInfos block ctxs) - -renderSpans :: HasSourceInfo a => a -> SourceBlock -> Html -renderSpans x block = - let x' = addSpanIds x in - let ctxs = gatherSourceInfo x' in - toHtml $ cdiv "code-block" $ spanDelimitedCode block ctxs - -highlightSyntax :: T.Text -> Html -highlightSyntax s = foldMap (uncurry syntaxSpan) classified - where classified = ignoreExcept $ parseit s (many (withSource classify) <* eof) - -syntaxSpan :: T.Text -> StrClass -> Html -syntaxSpan s NormalStr = toHtml s -syntaxSpan s c = H.span (toHtml s) ! class_ (stringValue className) - where - className = case c of - CommentStr -> "comment" - KeywordStr -> "keyword" - CommandStr -> "command" - SymbolStr -> "symbol" - TypeNameStr -> "type-name" - IsoSugarStr -> "iso-sugar" - WhitespaceStr -> "whitespace" - NormalStr -> error "Should have been matched already" - -data StrClass = NormalStr - | CommentStr | KeywordStr | CommandStr | SymbolStr | TypeNameStr - | IsoSugarStr | WhitespaceStr - -classify :: Parser StrClass -classify = - (try (char ':' >> lowerWord) >> return CommandStr) - <|> (symbol "-- " >> manyTill anySingle (void eol <|> eof) >> return CommentStr) - <|> (do s <- lowerWord - return $ if s `elem` keyWordStrs then KeywordStr else NormalStr) - <|> (upperWord >> return TypeNameStr) - <|> try (char '#' >> (char '?' <|> char '&' <|> char '|' <|> pure ' ') - >> lowerWord >> return IsoSugarStr) - <|> (some symChar >> return SymbolStr) - <|> (some space1 >> return WhitespaceStr) - <|> (anySingle >> return NormalStr) - -lowerWord :: Parser String -lowerWord = (:) <$> lowerChar <*> many alphaNumChar - -upperWord :: Parser String -upperWord = (:) <$> upperChar <*> many alphaNumChar +renderSpans :: BlockId -> LexemeInfo -> T.Text -> Markup +renderSpans blockId lexInfo sourceText = cdiv "code-block" do + runTextWalkerT sourceText do + forM_ (lexemeList lexInfo) \sourceId -> do + let (lexemeTy, (l, r)) = fromJust $ M.lookup sourceId (lexemeInfo lexInfo) + takeTo l >>= emitSpan Nothing (Just "comment") + takeTo r >>= emitSpan (Just (blockId, sourceId)) (lexemeClass lexemeTy) + takeRest >>= emitSpan Nothing (Just "comment") + +emitSpan :: Maybe (BlockId, SrcId) -> Maybe String -> T.Text -> TextWalker () +emitSpan maybeSrcId className t = lift do + let classAttr = case className of + Nothing -> mempty + Just c -> class_ (stringValue c) + let idAttr = case maybeSrcId of + Nothing -> mempty + Just (bid, SrcId sid) -> At.id (fromString $ "span_" ++ show bid ++ "_"++ show sid) + H.span (toHtml t) ! classAttr ! idAttr + +lexemeClass :: LexemeType -> Maybe String +lexemeClass = \case + Keyword -> Just "keyword" + Symbol -> Just "symbol" + TypeName -> Just "type-name" + LowerName -> Nothing + UpperName -> Nothing + LiteralLexeme -> Just "literal" + StringLiteralLexeme -> Nothing + MiscLexeme -> Nothing + +type TextWalker a = StateT (Int, T.Text) MarkupM a + +runTextWalkerT :: T.Text -> TextWalker a -> MarkupM a +runTextWalkerT t cont = evalStateT cont (0, t) + +-- index is the *absolute* index, from the very beginning +takeTo :: Int -> TextWalker T.Text +takeTo startPos = do + (curPos, curText) <- get + let (prefix, remText) = T.splitAt (startPos- curPos) curText + put (startPos, remText) + return prefix + +takeRest :: TextWalker T.Text +takeRest = do + (curPos, curText) <- get + put (curPos + T.length curText, mempty) + return curText diff --git a/src/lib/Runtime.hs b/src/lib/Runtime.hs index 1ea5dad66..885088c21 100644 --- a/src/lib/Runtime.hs +++ b/src/lib/Runtime.hs @@ -24,17 +24,14 @@ import Control.Monad import Control.Concurrent import Control.Exception hiding (throw) import qualified Control.Exception as E -import qualified System.Environment as E import Err -import Logging -import Util (measureSeconds) +import MonadUtil import PPrint () -import CUDA (synchronizeCUDA) -import Types.Core hiding (DexDestructor) +import Types.Top hiding (DexDestructor) +import Types.Source hiding (CInt) import Types.Primitives -import Types.Misc -- === One-shot evaluation === @@ -55,75 +52,44 @@ data BenchRequirement = data LLVMCallable = LLVMCallable { nativeFun :: NativeFunction - , benchRequired :: BenchRequirement , logger :: PassLogger - , resultTypes :: [BaseType] - } + , resultTypes :: [BaseType] } -- The NativeFunction needs to have been compiled with EntryFunCC. callEntryFun :: LLVMCallable -> [LitVal] -> IO [LitVal] -callEntryFun LLVMCallable{nativeFun, benchRequired, logger, resultTypes} args = do +callEntryFun LLVMCallable{nativeFun, logger, resultTypes} args = do withPipeToLogger logger \fd -> allocaCells (length args) \argsPtr -> allocaCells (length resultTypes) \resultPtr -> do storeLitVals argsPtr args let fPtr = castFunPtr $ nativeFunPtr nativeFun - evalTime <- checkedCallFunPtr fd argsPtr resultPtr fPtr + checkedCallFunPtr fd argsPtr resultPtr fPtr results <- loadLitVals resultPtr resultTypes - case benchRequired of - NoBench -> logSkippingFilter logger [EvalTime evalTime Nothing] - DoBench shouldSyncCUDA -> do - let sync = when shouldSyncCUDA $ synchronizeCUDA - (avgTime, benchRuns, totalTime) <- runBench do - let (CInt fd') = fdFD fd - exitCode <- callFunPtr fPtr fd' argsPtr resultPtr - unless (exitCode == 0) $ throw RuntimeErr "" - freeLitVals resultPtr resultTypes - sync - logSkippingFilter logger [EvalTime avgTime (Just (benchRuns, totalTime + evalTime))] return results {-# SCC callEntryFun #-} -checkedCallFunPtr :: FD -> Ptr () -> Ptr () -> DexExecutable -> IO Double +checkedCallFunPtr :: FD -> Ptr () -> Ptr () -> DexExecutable -> IO () checkedCallFunPtr fd argsPtr resultPtr fPtr = do let (CInt fd') = fdFD fd - (exitCode, duration) <- measureSeconds $ do - exitCode <- callFunPtr fPtr fd' argsPtr resultPtr - return exitCode - unless (exitCode == 0) $ throw RuntimeErr "" - return duration + exitCode <- callFunPtr fPtr fd' argsPtr resultPtr + unless (exitCode == 0) $ throwErr RuntimeErr withPipeToLogger :: PassLogger -> (FD -> IO a) -> IO a withPipeToLogger logger writeAction = do result <- snd <$> withPipe - (\h -> readStream h \s -> logSkippingFilter logger [TextOut s]) + (\h -> readStream h \s -> ioLogAction logger $ Outputs [TextOut s]) (\h -> handleToFd h >>= writeAction) case result of Left e -> E.throw e Right ans -> return ans -runBench :: IO () -> IO (Double, Int, Double) -runBench run = do - exampleDuration <- snd <$> measureSeconds run - test_mode <- (Just "t" ==) <$> E.lookupEnv "DEX_TEST_MODE" - let timeBudget = (2 - exampleDuration) `max` 0 -- seconds - let benchRuns = if test_mode - then 0 - else (ceiling $ timeBudget / exampleDuration) :: Int - totalTime' <- liftM snd $ measureSeconds $ do - forM_ [1..benchRuns] $ const run - let totalTime = totalTime' + exampleDuration - avgTime = totalTime / (fromIntegral $ benchRuns + 1) - - return (avgTime, benchRuns + 1, totalTime) - -- === serializing scalars === loadLitVals :: MonadIO m => Ptr () -> [BaseType] -> m [LitVal] loadLitVals p types = zipWithM loadLitVal (ptrArray p) types -freeLitVals :: MonadIO m => Ptr () -> [BaseType] -> m () -freeLitVals p types = zipWithM_ freeLitVal (ptrArray p) types +_freeLitVals :: MonadIO m => Ptr () -> [BaseType] -> m () +_freeLitVals p types = zipWithM_ freeLitVal (ptrArray p) types storeLitVals :: MonadIO m => Ptr () -> [LitVal] -> m () storeLitVals p xs = zipWithM_ storeLitVal (ptrArray p) xs diff --git a/src/lib/RuntimePrint.hs b/src/lib/RuntimePrint.hs index 4a4c2c6a5..dd12d67a5 100644 --- a/src/lib/RuntimePrint.hs +++ b/src/lib/RuntimePrint.hs @@ -16,6 +16,7 @@ import IRVariants import MTL1 import Name import CheapReduction +import PPrint import Types.Core import Types.Source import Types.Primitives @@ -27,13 +28,13 @@ newtype Printer (n::S) (a :: *) = Printer { runPrinter' :: ReaderT1 (Atom CoreIR , Fallible, ScopeReader, MonadFail, EnvExtender, CBuilder, ScopableBuilder CoreIR) type Print n = Printer n () -showAny :: EnvReader m => Atom CoreIR n -> m n (Block CoreIR n) +showAny :: EnvReader m => Atom CoreIR n -> m n (CExpr n) showAny x = liftPrinter $ showAnyRec (sink x) liftPrinter :: EnvReader m => (forall l. (DExt n l, Emits l) => Print l) - -> m n (CBlock n) + -> m n (CExpr n) liftPrinter cont = liftBuilder $ buildBlock $ withBuffer \buf -> runReaderT1 buf (runPrinter' cont) @@ -58,30 +59,32 @@ emitCharLit c = emitChar $ charRepVal c showAnyRec :: forall n. Emits n => CAtom n -> Print n showAnyRec atom = case getType atom of - -- hack to print chars nicely. TODO: make `Char` a newtype - TC t -> case t of - BaseType bt -> case bt of - Vector _ _ -> error "not implemented" - PtrType _ -> printTypeOnly "pointer" - Scalar _ -> do - (n, tab) <- fromPair =<< emitExpr (PrimOp $ MiscOp $ ShowScalar atom) - logicalTabTy <- finTabTyCore (NewtypeCon NatCon n) CharRepTy - tab' <- emitExpr $ PrimOp $ MiscOp $ UnsafeCoerce logicalTabTy tab - emitCharTab tab' - -- TODO: we could do better than this but it's not urgent because raw sum types - -- aren't user-facing. - SumType _ -> printAsConstant - RefType _ _ -> printTypeOnly "reference" - HeapType -> printAsConstant - ProdType _ -> do - xs <- getUnpacked atom - parens $ sepBy ", " $ map rec xs - -- TODO: traverse the type and print out data components - TypeKind -> printAsConstant - ProjectEltTy _ _ _ -> error "not implemented" + TyCon con -> showAnyTyCon con atom + StuckTy _ e -> error $ "unexpected stuck type expression: " ++ pprint e + +showAnyTyCon :: forall n. Emits n => TyCon CoreIR n -> CAtom n -> Print n +showAnyTyCon tyCon atom = case tyCon of + BaseType bt -> case bt of + Vector _ _ -> error "not implemented" + PtrType _ -> printTypeOnly "pointer" + Scalar _ -> do + (n, tab) <- fromPair =<< emit (ShowScalar atom) + logicalTabTy <- finTabTyCore (Con $ NewtypeCon NatCon n) CharRepTy + tab' <- emit $ UnsafeCoerce logicalTabTy tab + emitCharTab tab' + -- TODO: we could do better than this but it's not urgent because raw sum types + -- aren't user-facing. + SumType _ -> printAsConstant + RefType _ _ -> printTypeOnly "reference" + HeapType -> printAsConstant + ProdType _ -> do + xs <- getUnpacked atom + parens $ sepBy ", " $ map rec xs + -- TODO: traverse the type and print out data components + TypeKind -> printAsConstant Pi _ -> printTypeOnly "function" TabPi _ -> brackets $ forEachTabElt atom \iOrd x -> do - isFirst <- ieq iOrd (NatVal 0) + isFirst <- emit $ BinOp (ICmp Equal) iOrd (NatVal 0) void $ emitIf isFirst UnitTy (return UnitVal) (emitLit ", " >> return UnitVal) rec x NewtypeTyCon tc -> case tc of @@ -89,12 +92,12 @@ showAnyRec atom = case getType atom of Nat -> do n <- unwrapNewtype atom -- Cast to Int so that it prints in decimal instead of hex - let intTy = TC (BaseType (Scalar Int64Type)) - emitExpr (PrimOp $ MiscOp $ CastOp intTy n) >>= rec + let intTy = toType $ BaseType (Scalar Int64Type) + emit (CastOp intTy n) >>= rec EffectRowKind -> printAsConstant -- hack to print strings nicely. TODO: make `Char` a newtype - UserADTType "List" _ (TyConParams [Explicit] [Type Word8Ty]) -> do - charTab <- normalizeNaryProj [ProjectProduct 1, UnwrapNewtype] atom + UserADTType "List" _ (TyConParams [Explicit] [Con (TyConAtom (BaseType (Scalar (Word8Type))))]) -> do + charTab <- applyProjections [ProjectProduct 1, UnwrapNewtype] atom emitCharLit '"' emitCharTab charTab emitCharLit '"' @@ -107,7 +110,7 @@ showAnyRec atom = case getType atom of showDataCon (sink $ cons !! i) arg return UnitVal StructFields fields -> do - emitLit tySourceName + emitLit $ pprint tySourceName parens do sepBy ", " $ (enumerate fields) <&> \(i, _) -> rec =<< projectStruct i atom @@ -115,20 +118,19 @@ showAnyRec atom = case getType atom of showDataCon :: Emits n' => DataConDef n' -> CAtom n' -> Print n' showDataCon (DataConDef sn _ _ projss) arg = do case projss of - [] -> emitLit sn + [] -> emitLit $ pprint sn _ -> parens do - emitLit (sn ++ " ") + emitLit (pprint sn ++ " ") sepBy " " $ projss <&> \projs -> -- we use `init` to strip off the `UnwrapCompoundNewtype` since -- we're already under the case alternative - rec =<< normalizeNaryProj (init projs) arg + rec =<< applyProjections (init projs) arg DepPairTy _ -> parens do (x, y) <- fromPair atom rec x >> emitLit " ,> " >> rec y -- Done well, this could let you inspect the results of dictionary synthesis -- and maybe even debug synthesis failures. DictTy _ -> printAsConstant - TyVar v -> error $ "unexpected type variable: " ++ pprint v where rec :: Emits n' => CAtom n' -> Print n' rec = showAnyRec @@ -162,18 +164,18 @@ withBuffer => (forall l . (Emits l, DExt n l) => CAtom l -> BuilderM CoreIR l ()) -> BuilderM CoreIR n (CAtom n) withBuffer cont = do - lam <- withFreshBinder "h" (TC HeapType) \h -> do - bufTy <- bufferTy (Var $ binderVar h) + lam <- withFreshBinder "h" (TyCon HeapType) \h -> do + bufTy <- bufferTy (toAtom $ binderVar h) withFreshBinder "buf" bufTy \b -> do - let eff = OneEffect (RWSEffect State (Var $ sink $ binderVar h)) + let eff = OneEffect (RWSEffect State (toAtom $ sink $ binderVar h)) body <- buildBlock do - cont $ sink $ Var $ binderVar b + cont $ sink $ toAtom $ binderVar b return UnitVal let binders = BinaryNest h b let expls = [Inferred Nothing Unify, Explicit] let piTy = CorePiType ExplicitApp expls binders $ EffTy eff UnitTy let lam = LamExpr (BinaryNest h b) body - return $ Lam $ CoreLamExpr piTy lam + return $ toAtom $ CoreLamExpr piTy lam applyPreludeFunction "with_stack_internal" [lam] bufferTy :: EnvReader m => CAtom n -> m n (CType n) @@ -185,7 +187,7 @@ bufferTy h = do extendBuffer :: (Emits n, CBuilder m) => CAtom n -> CAtom n -> m n () extendBuffer buf tab = do RefTy h _ <- return $ getType buf - TabPi t <- return $ getType tab + TyCon (TabPi t) <- return $ getType tab n <- applyIxMethodCore Size (tabIxType t) [] void $ applyPreludeFunction "stack_extend_internal" [n, h, buf, tab] @@ -198,38 +200,36 @@ pushBuffer buf x = do stringLitAsCharTab :: (Emits n, CBuilder m) => String -> m n (CAtom n) stringLitAsCharTab s = do t <- finTabTyCore (NatVal $ fromIntegral $ length s) CharRepTy - emitExpr $ TabCon Nothing t (map charRepVal s) + emit $ TabCon Nothing t (map charRepVal s) finTabTyCore :: (Fallible1 m, EnvReader m) => CAtom n -> CType n -> m n (CType n) -finTabTyCore n eltTy = do - d <- mkDictAtom $ IxFin n - return $ IxType (FinTy n) (IxDictAtom d) ==> eltTy +finTabTyCore n eltTy = return $ IxType (FinTy n) (DictCon $ IxFin n) ==> eltTy -getPreludeFunction :: EnvReader m => String -> m n (CAtom n) +getPreludeFunction :: EnvReader m => SourceName -> m n (CAtom n) getPreludeFunction sourceName = do lookupSourceMap sourceName >>= \case Just uvar -> case uvar of - UAtomVar v -> Var <$> toAtomVar v + UAtomVar v -> toAtom <$> toAtomVar v _ -> notfound Nothing -> notfound - where notfound = error $ "Function not defined: " ++ sourceName + where notfound = error $ "Function not defined: " ++ pprint sourceName -applyPreludeFunction :: (Emits n, CBuilder m) => String -> [CAtom n] -> m n (CAtom n) +applyPreludeFunction :: (Emits n, CBuilder m) => SourceName -> [CAtom n] -> m n (CAtom n) applyPreludeFunction name args = do f <- getPreludeFunction name naryApp f args -strType :: EnvReader m => m n (CType n) -strType = constructPreludeType "List" $ TyConParams [Explicit] [Type CharRepTy] +strType :: forall n m. EnvReader m => m n (CType n) +strType = constructPreludeType "List" $ TyConParams [Explicit] [toAtom (CharRepTy :: CType n)] -constructPreludeType :: EnvReader m => String -> TyConParams n -> m n (CType n) +constructPreludeType :: EnvReader m => SourceName -> TyConParams n -> m n (CType n) constructPreludeType sourceName params = do lookupSourceMap sourceName >>= \case Just uvar -> case uvar of - UTyConVar v -> return $ TypeCon sourceName v params + UTyConVar v -> return $ toType $ UserADTType sourceName v params _ -> notfound Nothing -> notfound - where notfound = error $ "Type constructor not defined: " ++ sourceName + where notfound = error $ "Type constructor not defined: " ++ pprint sourceName forEachTabElt :: (Emits n, ScopableBuilder CoreIR m) @@ -237,10 +237,10 @@ forEachTabElt -> (forall l. (Emits l, DExt n l) => CAtom l -> CAtom l -> m l ()) -> m n () forEachTabElt tab cont = do - TabPi t <- return $ getType tab + TyCon (TabPi t) <- return $ getType tab let ixTy = tabIxType t void $ buildFor "i" Fwd ixTy \i -> do - x <- tabApp (sink tab) (Var i) - i' <- applyIxMethodCore Ordinal (sink ixTy) [Var i] + x <- tabApp (sink tab) (toAtom i) + i' <- applyIxMethodCore Ordinal (sink ixTy) [toAtom i] cont i' x return $ UnitVal diff --git a/src/lib/Simplify.hs b/src/lib/Simplify.hs index 86e395e80..f44aa1099 100644 --- a/src/lib/Simplify.hs +++ b/src/lib/Simplify.hs @@ -10,12 +10,10 @@ module Simplify ( simplifyTopBlock, simplifyTopFunction, ReconstructAtom (..), applyReconTop, linearizeTopFun, SimplifiedTopLam (..)) where -import Control.Applicative import Control.Category ((>>>)) import Control.Monad import Control.Monad.Reader import Data.Maybe -import Data.Text.Prettyprint.Doc (Pretty (..), hardline) import Builder import CheapReduction @@ -27,12 +25,13 @@ import IRVariants import Linearize import Name import Subst -import Optimize (peepholeOp) +import PPrint import QueryType import RuntimePrint import Transpose import Types.Core import Types.Source +import Types.Top import Types.Primitives import Util (enumerate) @@ -68,134 +67,206 @@ tryAsDataAtom atom = do isData ty >>= \case False -> return Nothing True -> Just <$> do - repAtom <- go atom + repAtom <- dropSubst $ toDataAtom atom return (repAtom, ty) - where - go :: Emits n => CAtom n -> SimplifyM i n (SAtom n) - go = \case - Var v -> lookupAtomName (atomVarName v) >>= \case - LetBound (DeclBinding _ (Atom x)) -> go x - _ -> error "Shouldn't have irreducible top names left" - Con con -> Con <$> case con of - Lit v -> return $ Lit v - ProdCon xs -> ProdCon <$> mapM go xs - SumCon tys tag x -> SumCon <$> mapM getRepType tys <*> pure tag <*> go x - HeapVal -> return HeapVal - PtrVar t v -> return $ PtrVar t v - DepPair x y ty -> do - DepPairTy ty' <- getRepType $ DepPairTy ty - DepPair <$> go x <*> go y <*> pure ty' - ProjectElt _ UnwrapNewtype x -> go x - -- TODO: do we need to think about a case like `fst (1, \x.x)`, where - -- the projection is data but the argument isn't? - ProjectElt _ (ProjectProduct i) x -> normalizeProj (ProjectProduct i) =<< go x - NewtypeCon _ x -> go x - SimpInCore x -> case x of - LiftSimp _ x' -> return x' - LiftSimpFun _ _ -> notData - TabLam _ tabLam -> forceTabLam tabLam - ACase scrut alts resultTy -> forceACase scrut alts resultTy - Lam _ -> notData - DictCon _ _ -> notData - Eff _ -> notData - DictHole _ _ _ -> notData - TypeAsAtom _ -> notData - where - notData = error $ "Not runtime-representable data: " ++ pprint atom - -forceTabLam :: Emits n => TabLamExpr n -> SimplifyM i n (SAtom n) -forceTabLam (PairE ixTy (Abs b ab)) = - buildFor (getNameHint b) Fwd ixTy \v -> do - result <- applyRename (b@>(atomVarName v)) ab >>= emitDecls - toDataAtomIgnoreRecon result - -type NaryTabLamExpr = Abs (Nest SBinder) (Abs (Nest SDecl) CAtom) - -fromNaryTabLam :: Int -> CAtom n -> Maybe (Int, NaryTabLamExpr n) -fromNaryTabLam maxDepth | maxDepth <= 0 = error "expected positive number of args" -fromNaryTabLam maxDepth = \case - SimpInCore (TabLam _ (PairE _ (Abs b body))) -> - extend <|> (Just $ (1, Abs (Nest b Empty) body)) - where - extend = case body of - Abs Empty lam | maxDepth > 1 -> do - (d, Abs (Nest b2 bs2) body2) <- fromNaryTabLam (maxDepth - 1) lam - return $ (d + 1, Abs (Nest b (Nest b2 bs2)) body2) - _ -> Nothing - _ -> Nothing - -forceACase :: Emits n => SAtom n -> [Abs SBinder CAtom n] -> CType n -> SimplifyM i n (SAtom n) -forceACase scrut alts resultTy = do - resultTy' <- getRepType resultTy - buildCase scrut resultTy' \i arg -> do - Abs b result <- return $ alts !! i - applySubst (b@>SubstVal arg) result >>= toDataAtomIgnoreRecon + +data WithSubst (e::E) (o::S) where + WithSubst :: Subst AtomSubstVal i o -> e i -> WithSubst e o + +type ACase = SStuck `PairE` ListE (Abs SBinder CAtom) `PairE` CType + +data ConcreteCAtom (n::S) = + CCCon (WithSubst (Con CoreIR) n) + | CCLiftSimp (CType n) (Stuck SimpIR n) + | CCFun (ConcreteCFun n) + | CCTabLam (WithSubst TabLamExpr n) + | CCACase (WithSubst ACase n) + +data ConcreteCFun (n::S) = + CCLiftSimpFun (CorePiType n) (LamExpr SimpIR n) + | CCNoInlineFun (CAtomVar n) (CType n) (CAtom n) + | CCFFIFun (CorePiType n) (TopFunName n) + +forceConstructor :: CAtom i -> SimplifyM i o (ConcreteCAtom o) +forceConstructor atom = withDistinct case atom of + Stuck _ stuck -> forceStuck stuck + Con con -> do + subst <- getSubst + return $ CCCon $ WithSubst subst con + +forceStuck :: forall i o . CStuck i -> SimplifyM i o (ConcreteCAtom o) +forceStuck stuck = withDistinct case stuck of + Var v -> lookupSubstM (atomVarName v) >>= \case + SubstVal x -> dropSubst $ forceConstructor x + Rename v' -> lookupAtomName v' >>= \case + LetBound (DeclBinding _ (Atom x)) -> dropSubst $ forceConstructor x + NoinlineFun t f -> do + v'' <- toAtomVar v' + return $ CCFun $ CCNoInlineFun v'' t f + FFIFunBound t f -> return $ CCFun $ CCFFIFun t f + _ -> error "shouldn't have other CVars left" + LiftSimp _ x -> do + -- the subst should be rename-only for `x`. We should make subst IR-specific + s <- getSubst + let s' = newSubst \v -> case s ! v of + SubstVal _ -> error "subst should be rename-only for SimpIR vars" -- TODO: make subst IR-specific + Rename v' -> v' + x' <- runSubstReaderT s' $ renameM x + returnLifted x' + -- We "thunk" ACase rather than forcing it because different use-cases require different ways to force it + ACase e alts resultTy -> do + subst <- getSubst + return $ CCACase $ WithSubst subst $ e `PairE` ListE alts `PairE` resultTy + TabLam e -> do + subst <- getSubst + return $ CCTabLam $ WithSubst subst e + StuckProject i x -> forceStuck x >>= \case + CCLiftSimp _ x' -> returnLifted $ StuckProject i x' + CCCon (WithSubst s con) -> withSubst s case con of + ProdCon xs -> forceConstructor (xs!!i) + DepPair l r _ -> forceConstructor ([l, r]!!i) + _ -> error "not a product" + CCACase x' -> pushUnderACase x' \x'' -> reduceProj i x'' + CCFun _ -> error "not a product" + CCTabLam _ -> error "not a product" + StuckTabApp f x -> forceStuck f >>= \case + CCLiftSimp _ f' -> do + x' <- toDataAtom x + returnLifted $ StuckTabApp f' x' + CCTabLam (WithSubst s (PairE _ (Abs b body))) -> do + x' <- toDataAtom x + result <- withSubst s $ extendSubst (b@>SubstVal x') $ substM body + dropSubst $ forceConstructor result + CCACase f' -> pushUnderACase f' \f'' -> reduceTabApp f'' =<< substM x + CCCon _ -> error "not a table" + CCFun _ -> error "not a table" + StuckUnwrap x -> forceStuck x >>= \case + CCCon (WithSubst s con) -> case con of + NewtypeCon _ x' -> withSubst s $ forceConstructor x' + _ -> error "not a newtype" + CCLiftSimp _ x' -> returnLifted x' + CCACase x' -> pushUnderACase x' \x'' -> reduceUnwrap x'' + CCFun _ -> error "not a newtype" + CCTabLam _ -> error "not a newtype" + InstantiatedGiven _ _ -> error "shouldn't have this left" + SuperclassProj _ _ -> error "shouldn't have this left" + PtrVar ty p -> do + p' <- substM p + returnLifted $ PtrVar ty p' + LiftSimpFun t f -> CCFun <$> (CCLiftSimpFun <$> substM t <*> substM f) + where + returnLifted :: SStuck o -> SimplifyM i o (ConcreteCAtom o) + returnLifted s = do + resultTy <- getType <$> substMStuck stuck + return $ CCLiftSimp resultTy s + +pushUnderACase + :: WithSubst ACase o + -> (forall o'. DExt o o' => CAtom o' -> SimplifyM i o' (CAtom o')) + -> SimplifyM i o (ConcreteCAtom o) +pushUnderACase _ _ = undefined +-- pushUnderACase (WithSubst s (scrut `PairE` ListE alts `PairE` resultTy)) cont = undefined +-- TODO: make a buildACase to use here and elsewhere in Simplify. Maybe in CheapReduce too? + + +forceACase + :: Emits o => WithSubst ACase o + -> (forall o'. (Emits o', DExt o o') => ConcreteCAtom o' -> SimplifyM i o' (CAtom o')) + -> SimplifyM i o (CAtom o) +forceACase (WithSubst subst (scrut `PairE` ListE alts `PairE` resultTy)) cont = do + resultTy' <- withSubst subst $ substM resultTy + scrut' <- withSubst subst $ substMStuck scrut + defuncCase scrut' resultTy' \i x -> do + Abs b body <- return $ alts !! i + body' <- withSubst (sink subst) $ extendSubst (b@>SubstVal x) $ forceConstructor body + cont body' tryGetRepType :: Type CoreIR n -> SimplifyM i n (Maybe (SType n)) tryGetRepType t = isData t >>= \case False -> return Nothing - True -> Just <$> getRepType t - -getRepType :: Type CoreIR n -> SimplifyM i n (SType n) -getRepType ty = go ty where - go :: Type CoreIR n -> SimplifyM i n (SType n) - go = \case - TC con -> TC <$> case con of - BaseType b -> return $ BaseType b - ProdType ts -> ProdType <$> mapM go ts - SumType ts -> SumType <$> mapM go ts - RefType h a -> RefType <$> toDataAtomIgnoreReconAssumeNoDecls h <*> go a - TypeKind -> error $ notDataType - HeapType -> return $ HeapType - DepPairTy (DepPairType expl b@(_:>l) r) -> do - l' <- go l - withFreshBinder (getNameHint b) l' \b' -> do - x <- liftSimpAtom (sink l) (Var $ binderVar b') - r' <- go =<< applySubst (b@>SubstVal x) r - return $ DepPairTy $ DepPairType expl b' r' - TabPi tabTy -> do - let ixTy = tabIxType tabTy - IxType t' d' <- simplifyIxType ixTy - withFreshBinder (getNameHint tabTy) t' \b' -> do - x <- liftSimpAtom (sink $ ixTypeType ixTy) (Var $ binderVar b') - bodyTy' <- go =<< instantiate (sink tabTy) [x] - return $ TabPi $ TabPiType d' b' bodyTy' - NewtypeTyCon con -> do - (_, ty') <- unwrapNewtypeType con - go ty' - Pi _ -> error notDataType - DictTy _ -> error notDataType - TyVar _ -> error "Shouldn't have type variables in CoreIR IR with SimpIR builder names" - ProjectEltTy _ _ _ -> error "Shouldn't have this left" - where notDataType = "Not a type of runtime-representable data: " ++ pprint ty - -toDataAtom :: Emits n => CAtom n -> SimplifyM i n (SAtom n, Type CoreIR n) -toDataAtom x = tryAsDataAtom x >>= \case - Just x' -> return x' - Nothing -> error $ "Not a data atom: " ++ pprint x - -simplifyDataAtom :: Emits o => CAtom i -> SimplifyM i o (SAtom o) -simplifyDataAtom x = toDataAtomIgnoreRecon =<< simplifyAtom x - -toDataAtomIgnoreRecon :: Emits n => CAtom n -> SimplifyM i n (SAtom n) -toDataAtomIgnoreRecon x = fst <$> toDataAtom x - -toDataAtomIgnoreReconAssumeNoDecls :: CAtom n -> SimplifyM i n (SAtom n) -toDataAtomIgnoreReconAssumeNoDecls x = do - Abs decls result <- buildScoped $ fst <$> toDataAtom (sink x) + True -> Just <$> dropSubst (getRepType t) + +getRepType :: Type CoreIR i -> SimplifyM i o (SType o) +getRepType (StuckTy _ stuck) = + substMStuck stuck >>= \case + Stuck _ _ -> error "shouldn't have stuck CType after substitution" + Con (TyConAtom tyCon) -> dropSubst $ getRepType (TyCon tyCon) + Con _ -> error "not a type" +getRepType (TyCon con) = case con of + BaseType b -> return $ toType $ BaseType b + ProdType ts -> toType . ProdType <$> mapM getRepType ts + SumType ts -> toType . SumType <$> mapM getRepType ts + RefType h a -> toType <$> (RefType <$> toDataAtomAssumeNoDecls h <*> getRepType a) + HeapType -> return $ toType HeapType + DepPairTy (DepPairType expl b r) -> do + withSimplifiedBinder b \b' -> do + r' <- getRepType r + return $ toType $ DepPairType expl b' r' + TabPi (TabPiType ixDict b r) -> do + ixDict' <- simplifyIxDict ixDict + withSimplifiedBinder b \b' -> do + r' <- getRepType r + return $ toType $ TabPi $ TabPiType ixDict' b' r' + NewtypeTyCon con' -> do + (_, ty') <- unwrapNewtypeType =<< substM con' + dropSubst $ getRepType ty' + Pi _ -> error notDataType + DictTy _ -> error notDataType + TypeKind -> error notDataType + where notDataType = "Not a type of runtime-representable data" + +toDataAtom :: CAtom i -> SimplifyM i o (SAtom o) +toDataAtom (Con con) = case con of + Lit v -> return $ toAtom $ Lit v + ProdCon xs -> toAtom . ProdCon <$> mapM rec xs + SumCon tys tag x -> toAtom <$> (SumCon <$> mapM getRepType tys <*> pure tag <*> rec x) + HeapVal -> return $ toAtom HeapVal + DepPair x y ty -> do + TyCon (DepPairTy ty') <- getRepType $ TyCon $ DepPairTy ty + toAtom <$> (DepPair <$> rec x <*> rec y <*> pure ty') + NewtypeCon _ x -> rec x + Lam _ -> notData + DictConAtom _ -> notData + Eff _ -> notData + TyConAtom _ -> notData + where + rec = toDataAtom + notData = error $ "Not runtime-representable data" +toDataAtom (Stuck _ stuck) = forceStuck stuck >>= \case + CCCon (WithSubst s con) -> withSubst s $ toDataAtom (Con con) + CCLiftSimp _ e -> mkStuck e + CCFun _ -> notData + CCACase _ -> notData -- TODO: make sure we observe this invariant" + CCTabLam _ -> notData -- TODO: make sure we observe this invariant" + where notData = error $ "Not runtime-representable data" + +toDataAtomAssumeNoDecls :: CAtom i -> SimplifyM i o (SAtom o) +toDataAtomAssumeNoDecls x = do + Abs decls result <- buildScoped $ toDataAtom x case decls of Empty -> return result _ -> error "unexpected decls" +withSimplifiedBinder + :: CBinder i i' + -> (forall o'. DExt o o' => Binder SimpIR o o' -> SimplifyM i' o' a) + -> SimplifyM i o a +withSimplifiedBinder (b:>ty) cont = do + tySimp <- getRepType ty + tyCore <- substM ty + withFreshBinder (getNameHint b) tySimp \b' -> do + x <- liftSimpAtom (sink tyCore) (toAtom $ binderVar b') + extendSubst (b@>SubstVal x) $ cont b' + withSimplifiedBinders :: Nest (Binder CoreIR) o any -> (forall o'. DExt o o' => Nest (Binder SimpIR) o o' -> [CAtom o'] -> SimplifyM i o' a) -> SimplifyM i o a withSimplifiedBinders Empty cont = getDistinct >>= \Distinct -> cont Empty [] withSimplifiedBinders (Nest (bCore:>ty) bsCore) cont = do - simpTy <- getRepType ty + simpTy <- dropSubst $ getRepType ty withFreshBinder (getNameHint bCore) simpTy \bSimp -> do - x <- liftSimpAtom (sink ty) (Var $ binderVar bSimp) + x <- liftSimpAtom (sink ty) (toAtom $ binderVar bSimp) -- TODO: carry a substitution instead of doing N^2 work like this Abs bsCore' UnitE <- applySubst (bCore@>SubstVal x) (EmptyAbs bsCore) withSimplifiedBinders bsCore' \bsSimp xs -> @@ -241,13 +312,13 @@ deriving instance ScopableBuilder SimpIR (SimplifyM i) -- === Top-level API === data SimplifiedTopLam n = SimplifiedTopLam (STopLam n) (ReconstructAtom n) -data SimplifiedBlock n = SimplifiedBlock (SBlock n) (ReconstructAtom n) +data SimplifiedBlock n = SimplifiedBlock (SExpr n) (ReconstructAtom n) simplifyTopBlock :: (TopBuilder m, Mut n) => TopBlock CoreIR n -> m n (SimplifiedTopLam n) simplifyTopBlock (TopLam _ _ (LamExpr Empty body)) = do SimplifiedBlock block recon <- liftSimplifyM do - {-# SCC "Simplify" #-} buildSimplifiedBlock $ simplifyBlock body + {-# SCC "Simplify" #-} buildSimplifiedBlock $ simplifyExpr body topLam <- asTopLam $ LamExpr Empty block return $ SimplifiedTopLam topLam recon simplifyTopBlock _ = error "not a block (nullary lambda)" @@ -263,7 +334,7 @@ applyReconTop :: (EnvReader m, Fallible1 m) => ReconstructAtom n -> SAtom n -> m applyReconTop = applyRecon instance GenericE SimplifiedBlock where - type RepE SimplifiedBlock = PairE SBlock ReconstructAtom + type RepE SimplifiedBlock = PairE SExpr ReconstructAtom fromE (SimplifiedBlock block recon) = PairE block recon {-# INLINE fromE #-} toE (PairE block recon) = SimplifiedBlock block recon @@ -272,13 +343,6 @@ instance GenericE SimplifiedBlock where instance SinkableE SimplifiedBlock instance RenameE SimplifiedBlock instance HoistableE SimplifiedBlock -instance CheckableE SimpIR SimplifiedBlock where - checkE (SimplifiedBlock block recon) = do - block' <- renameM block - effTy <- blockEffTy block' -- TODO: store this in the simplified block instead - block'' <- dropSubst $ checkBlock effTy block' - recon' <- renameM recon -- TODO: CheckableE instance for the recon too - return $ SimplifiedBlock block'' recon' instance Pretty (SimplifiedBlock n) where pretty (SimplifiedBlock block recon) = @@ -310,57 +374,70 @@ simpDeclsSubst simpDeclsSubst !s = \case Empty -> return s Nest (Let b (DeclBinding _ expr)) rest -> do - let hint = (getNameHint b) - x <- withSubst s $ simplifyExpr hint expr + x <- withSubst s $ simplifyExpr expr simpDeclsSubst (s <>> (b@>SubstVal x)) rest -simplifyExpr :: Emits o => NameHint -> Expr CoreIR i -> SimplifyM i o (CAtom o) -simplifyExpr hint expr = confuseGHC >>= \_ -> case expr of +simplifyExpr :: Emits o => Expr CoreIR i -> SimplifyM i o (CAtom o) +simplifyExpr expr = confuseGHC >>= \_ -> case expr of + Block _ (Abs decls body) -> simplifyDecls decls $ simplifyExpr body App (EffTy _ ty) f xs -> do ty' <- substM ty + f' <- forceConstructor f xs' <- mapM simplifyAtom xs - simplifyApp hint ty' f xs' - TabApp _ f xs -> do - xs' <- mapM simplifyAtom xs - f' <- simplifyAtom f - simplifyTabApp f' xs' + simplifyApp ty' f' xs' + TabApp _ f x -> withDistinct do + x' <- simplifyAtom x + f' <- forceConstructor f + simplifyTabApp f' x' Atom x -> simplifyAtom x - PrimOp op -> simplifyOp hint op + PrimOp op -> simplifyOp op ApplyMethod (EffTy _ ty) dict i xs -> do ty' <- substM ty xs' <- mapM simplifyAtom xs - dict' <- simplifyAtom dict + Just dict' <- toMaybeDict <$> simplifyAtom dict applyDictMethod ty' dict' i xs' TabCon _ ty xs -> do ty' <- substM ty - tySimp <- getRepType ty' - xs' <- forM xs \x -> simplifyDataAtom x - liftSimpAtom ty' =<< emitExpr (TabCon Nothing tySimp xs') + tySimp <- getRepType ty + xs' <- forM xs \x -> toDataAtom x + liftSimpAtom ty' =<< emit (TabCon Nothing tySimp xs') Case scrut alts (EffTy _ resultTy) -> do scrut' <- simplifyAtom scrut resultTy' <- substM resultTy defuncCaseCore scrut' resultTy' \i x -> do Abs b body <- return $ alts !! i - extendSubst (b@>SubstVal x) $ simplifyBlock body + extendSubst (b@>SubstVal x) $ simplifyExpr body + Project ty i x -> do + ty' <- substM ty + x' <- substM x + tryAsDataAtom x' >>= \case + Just (x'', _) -> liftSimpAtom ty' =<< proj i x'' + Nothing -> requireReduced $ Project ty' i x' + Unwrap _ _ -> requireReduced =<< substM expr + +requireReduced :: CExpr o -> SimplifyM i o (CAtom o) +requireReduced expr = reduceExpr expr >>= \case + Just x -> return x + Nothing -> error "couldn't reduce expression" simplifyRefOp :: Emits o => RefOp CoreIR i -> SAtom o -> SimplifyM i o (SAtom o) simplifyRefOp op ref = case op of MExtend (BaseMonoid em cb) x -> do - em' <- simplifyDataAtom em - x' <- simplifyDataAtom x + em' <- toDataAtom em + x' <- toDataAtom x (cb', CoerceReconAbs) <- simplifyLam cb emitRefOp $ MExtend (BaseMonoid em' cb') x' - MGet -> emitOp $ RefOp ref MGet + MGet -> emit $ RefOp ref MGet MPut x -> do - x' <- simplifyDataAtom x + x' <- toDataAtom x emitRefOp $ MPut x' MAsk -> emitRefOp MAsk IndexRef _ x -> do - x' <- simplifyDataAtom x - emitOp =<< mkIndexRef ref x' - ProjRef _ (ProjectProduct i) -> emitOp =<< mkProjRef ref (ProjectProduct i) + x' <- toDataAtom x + emit =<< mkIndexRef ref x' + ProjRef _ (ProjectProduct i) -> emit =<< mkProjRef ref (ProjectProduct i) ProjRef _ UnwrapNewtype -> return ref - where emitRefOp op' = emitOp $ RefOp ref op' + where emitRefOp op' = emit $ RefOp ref op' defuncCaseCore :: Emits o => Atom CoreIR o -> Type CoreIR o @@ -374,49 +451,41 @@ defuncCaseCore scrut resultTy cont = do let xCoreTy = altBinderTys !! i x' <- liftSimpAtom (sink xCoreTy) x cont i x' - Nothing -> case trySelectBranch scrut of - Just (i, arg) -> getDistinct >>= \Distinct -> cont i arg - Nothing -> go scrut where - go = \case - SimpInCore (ACase scrutSimp alts _) -> do - defuncCase scrutSimp resultTy \i x -> do - Abs altb altAtom <- return $ alts !! i - altAtom' <- applySubst (altb @> SubstVal x) altAtom - cont i altAtom' - NewtypeCon con scrut' | isSumCon con -> go scrut' - _ -> nope - nope = error $ "Don't know how to scrutinize non-data " ++ pprint scrut + Nothing -> case scrut of + Con (SumCon _ i arg) -> getDistinct >>= \Distinct -> cont i arg + _ -> error $ "Don't know how to scrutinize non-data " ++ pprint scrut defuncCase :: Emits o => Atom SimpIR o -> Type CoreIR o -> (forall o'. (Emits o', DExt o o') => Int -> SAtom o' -> SimplifyM i o' (CAtom o')) -> SimplifyM i o (CAtom o) defuncCase scrut resultTy cont = do - case trySelectBranch scrut of - Just (i, arg) -> getDistinct >>= \Distinct -> cont i arg - Nothing -> do - scrutTy <- return $ getType scrut - altBinderTys <- caseAltsBinderTys scrutTy + case scrut of + Con (SumCon _ i arg) -> getDistinct >>= \Distinct -> cont i arg + Con _ -> error "scrutinee must be a sum type" + Stuck _ _ -> do + altBinderTys <- caseAltsBinderTys (getType scrut) tryGetRepType resultTy >>= \case Just resultTyData -> do alts' <- forM (enumerate altBinderTys) \(i, bTy) -> do - buildAbs noHint bTy \x -> do - buildBlock $ cont i (sink $ Var x) >>= toDataAtomIgnoreRecon + buildAbs noHint bTy \x -> buildBlock do + ans <- cont i (toAtom $ sink x) + dropSubst $ toDataAtom ans caseExpr <- mkCase scrut resultTyData alts' - emitExpr caseExpr >>= liftSimpAtom resultTy + emit caseExpr >>= liftSimpAtom resultTy Nothing -> do split <- splitDataComponents resultTy (alts', closureTys, recons) <- unzip3 <$> forM (enumerate altBinderTys) \(i, bTy) -> do simplifyAlt split bTy $ cont i - let closureSumTy = SumTy closureTys + let closureSumTy = TyCon $ SumType closureTys let newNonDataTy = nonDataTy split alts'' <- forM (enumerate alts') \(i, alt) -> injectAltResult closureTys i alt caseExpr <- mkCase scrut (PairTy (dataTy split) closureSumTy) alts'' - caseResult <- emitExpr $ caseExpr + caseResult <- emit $ caseExpr (dataVal, sumVal) <- fromPair caseResult reconAlts <- forM (zip closureTys recons) \(ty, recon) -> - buildAbs noHint ty \v -> applyRecon (sink recon) (Var v) - let nonDataVal = SimpInCore $ ACase sumVal reconAlts newNonDataTy + buildAbs noHint ty \v -> applyRecon (sink recon) (toAtom v) + nonDataVal <- reduceACase sumVal reconAlts newNonDataTy Distinct <- getDistinct fromSplit split dataVal nonDataVal @@ -427,7 +496,7 @@ simplifyAlt -> SimplifyM i o (Alt SimpIR o, SType o, ReconstructAtom o) simplifyAlt split ty cont = do withFreshBinder noHint ty \b -> do - ab <- buildScoped $ cont $ sink $ Var $ binderVar b + ab <- buildScoped $ cont $ sink $ toAtom $ binderVar b (body, recon) <- refreshAbs ab \decls result -> do let locals = toScopeFrag b >>> toScopeFrag decls -- TODO: this might be too cautious. The type only needs to @@ -437,78 +506,39 @@ simplifyAlt split ty cont = do (resultData, resultNonData) <- toSplit split result (newResult, reconAbs) <- telescopicCapture locals resultNonData return (Abs decls (PairVal resultData newResult), LamRecon reconAbs) - EffTy _ (PairTy _ nonDataType) <- blockEffTy body + body' <- mkBlock body + PairTy _ nonDataType <- return $ getType body' let nonDataType' = ignoreHoistFailure $ hoist b nonDataType - return (Abs b body, nonDataType', recon) - -simplifyApp :: forall i o. Emits o - => NameHint -> CType o -> CAtom i -> [CAtom o] -> SimplifyM i o (CAtom o) -simplifyApp hint resultTy f xs = case f of - Lam (CoreLamExpr _ lam) -> fast lam - _ -> slow =<< simplifyAtomAndInline f - where - fast :: LamExpr CoreIR i' -> SimplifyM i' o (CAtom o) - fast lam = withInstantiated lam xs \body -> simplifyBlock body - - slow :: CAtom o -> SimplifyM i o (CAtom o) - slow = \case - Lam (CoreLamExpr _ lam) -> dropSubst $ fast lam - SimpInCore (ACase e alts _) -> dropSubst do - defuncCase e resultTy \i x -> do - Abs b body <- return $ alts !! i - extendSubst (b@>SubstVal x) do - xs' <- mapM sinkM xs - simplifyApp hint (sink resultTy) body xs' - SimpInCore (LiftSimpFun _ lam) -> do - xs' <- mapM toDataAtomIgnoreRecon xs - result <- instantiate lam xs' >>= emitBlock - liftSimpAtom resultTy result - Var v -> do - lookupAtomName (atomVarName v) >>= \case - NoinlineFun _ _ -> simplifyTopFunApp v xs - FFIFunBound _ f' -> do - xs' <- mapM toDataAtomIgnoreRecon xs - liftSimpAtom resultTy =<< naryTopApp f' xs' - b -> error $ "Should only have noinline functions left " ++ pprint b - atom -> error $ "Unexpected function: " ++ pprint atom - --- | Like `simplifyAtom`, but will try to inline function definitions found --- in the environment. The only exception is when we're going to differentiate --- and the function has a custom derivative rule defined. --- TODO(dougalm): do we still need this? -simplifyAtomAndInline :: CAtom i -> SimplifyM i o (CAtom o) -simplifyAtomAndInline atom = confuseGHC >>= \_ -> case atom of - Var v -> do - env <- getSubst - case env ! atomVarName v of - Rename v' -> doInline =<< toAtomVar v' - SubstVal (Var v') -> doInline v' - SubstVal x -> return x - -- This is a hack because we weren't normalize the unwrapping of - -- `unit_type_scale` in `plot.dx`. We need a better system for deciding how to - -- normalize and inline. - ProjectElt _ i x -> do - x' <- simplifyAtom x >>= normalizeProj i - dropSubst $ simplifyAtomAndInline x' - _ -> simplifyAtom atom >>= \case - Var v -> doInline v - ans -> return ans - where - doInline v = do - lookupAtomName (atomVarName v) >>= \case - LetBound (DeclBinding _ (Atom x)) -> dropSubst $ simplifyAtomAndInline x - _ -> return $ Var v + return (Abs b body', nonDataType', recon) + +simplifyApp :: Emits o => CType o -> ConcreteCAtom o -> [CAtom o] -> SimplifyM i o (CAtom o) +simplifyApp resultTy f xs = case f of + CCCon (WithSubst s con) -> case con of + Lam (CoreLamExpr _ lam) -> withSubst s $ withInstantiated lam xs \body -> simplifyExpr body + _ -> error "not a function" + CCFun ccFun -> case ccFun of + CCLiftSimpFun _ lam -> do + xs' <- dropSubst $ mapM toDataAtom xs + result <- instantiate lam xs' >>= emit + liftSimpAtom resultTy result + CCNoInlineFun v _ _ -> simplifyTopFunApp v xs + CCFFIFun _ f' -> do + xs' <- dropSubst $ mapM toDataAtom xs + liftSimpAtom resultTy =<< naryTopApp f' xs' + CCACase aCase -> forceACase aCase \f' -> simplifyApp (sink resultTy) f' (sink <$> xs) + CCTabLam _ -> error "not a function" + CCLiftSimp _ _ -> error "not a function" simplifyTopFunApp :: Emits n => CAtomVar n -> [CAtom n] -> SimplifyM i n (CAtom n) simplifyTopFunApp fName xs = do - fTy@(Pi piTy) <- return $ getType fName + fTy@(TyCon (Pi piTy)) <- return $ getType fName resultTy <- typeOfApp fTy xs isData resultTy >>= \case True -> do (xsGeneralized, runtimeArgs) <- generalizeArgs piTy xs let spec = AppSpecialization fName xsGeneralized Just specializedFunction <- getSpecializedFunction spec >>= emitHoistedEnv - runtimeArgs' <- mapM toDataAtomIgnoreRecon runtimeArgs + runtimeArgs' <- dropSubst $ mapM toDataAtom runtimeArgs liftSimpAtom resultTy =<< naryTopApp specializedFunction runtimeArgs' False -> -- TODO: we should probably just fall back to inlining in this case, @@ -540,50 +570,38 @@ specializedFunCoreDefinition (AppSpecialization f (Abs bs staticArgs)) = do ListE staticArgs' <- applyRename (bs@@>(atomVarName <$> runtimeArgs)) staticArgs naryApp f' staticArgs' -simplifyTabApp :: forall i o. Emits o - => CAtom o -> [CAtom o] -> SimplifyM i o (CAtom o) -simplifyTabApp f [] = return f -simplifyTabApp f@(SimpInCore sic) xs = case sic of - TabLam _ _ -> do - case fromNaryTabLam (length xs) f of - Just (bsCount, ab) -> do - let (xsPref, xsRest) = splitAt bsCount xs - xsPref' <- mapM toDataAtomIgnoreRecon xsPref - block' <- instantiate ab xsPref' - atom <- emitDecls block' - simplifyTabApp atom xsRest - Nothing -> error "should never happen" - ACase e alts ty -> dropSubst do - resultTy <- typeOfTabApp ty xs - defuncCase e resultTy \i x -> do - Abs b body <- return $ alts !! i - extendSubst (b@>SubstVal x) do - xs' <- mapM sinkM xs - body' <- substM body - simplifyTabApp body' xs' - LiftSimp _ f' -> do - fTy <- return $ getType f - resultTy <- typeOfTabApp fTy xs - xs' <- mapM toDataAtomIgnoreRecon xs - liftSimpAtom resultTy =<< naryTabApp f' xs' - LiftSimpFun _ _ -> error "not implemented" -simplifyTabApp f _ = error $ "Unexpected table: " ++ pprint f - -simplifyIxType :: IxType CoreIR o -> SimplifyM i o (IxType SimpIR o) -simplifyIxType (IxType t ixDict) = do - t' <- getRepType t - IxType t' <$> case ixDict of - IxDictAtom (DictCon _ (IxFin n)) -> do - n' <- toDataAtomIgnoreReconAssumeNoDecls n - return $ IxDictRawFin n' - IxDictAtom d -> do - (dictAbs, params) <- generalizeIxDict =<< cheapNormalize d - params' <- mapM toDataAtomIgnoreReconAssumeNoDecls params - sdName <- requireIxDictCache dictAbs - return $ IxDictSpecialized t' sdName params' - IxDictRawFin n -> do - n' <- toDataAtomIgnoreReconAssumeNoDecls n - return $ IxDictRawFin n' +simplifyTabApp ::Emits o => ConcreteCAtom o -> CAtom o -> SimplifyM i o (CAtom o) +simplifyTabApp f x = case f of + CCLiftSimp fTy f' -> do + f'' <- mkStuck f' + resultTy <- typeOfTabApp fTy x + x' <- dropSubst $ toDataAtom x + liftSimpAtom resultTy =<< tabApp f'' x' + CCACase aCase -> forceACase aCase \f' -> simplifyTabApp f' (sink x) + CCTabLam (WithSubst s (PairE _ (Abs b ab))) -> do + x' <- dropSubst $ toDataAtom x + withSubst s $ extendSubst (b@>(SubstVal x')) $ substM ab + _ -> error "not a table" + +simplifyIxDict :: Dict CoreIR i -> SimplifyM i o (SDict o) +simplifyIxDict (StuckDict _ stuck) = forceStuck stuck >>= \case + CCCon (WithSubst s con) -> case con of + DictConAtom con' -> withSubst s $ simplifyIxDict (DictCon con') + _ -> error "not a dict" + CCLiftSimp _ _ -> error "not a dict" + CCFun _ -> error "not a dict" + CCTabLam _ -> error "not a dict" + CCACase _ -> error "not implemented" -- TODO: consider what to do about this +simplifyIxDict (DictCon con) = case con of + IxFin n -> DictCon <$> IxRawFin <$> toDataAtomAssumeNoDecls n + IxRawFin n -> DictCon <$> IxRawFin <$> toDataAtomAssumeNoDecls n + InstanceDict _ _ _ -> do + d <- DictCon <$> substM con + (dictAbs, params) <- generalizeIxDict d + params' <- dropSubst $ mapM toDataAtomAssumeNoDecls params + sdName <- requireIxDictCache dictAbs + return $ DictCon $ IxSpecialized sdName params' + DataData _ -> error "not an Ix dict" requireIxDictCache :: (HoistingTopBuilder TopEnvFrag m) => AbsDict n -> m n (Name SpecializedDictNameC n) @@ -609,7 +627,7 @@ simplifyDictMethod absDict@(Abs bs dict) method = do lamExpr <- liftBuilder $ buildTopLamFromPi ty \allArgs -> do let (extraArgs, methodArgs) = splitAt (nestLength bs) allArgs dict' <- applyRename (bs @@> (atomVarName <$> extraArgs)) dict - emitExpr =<< mkApplyMethod dict' (fromEnum method) (Var <$> methodArgs) + emit =<< mkApplyMethod dict' (fromEnum method) (toAtom <$> methodArgs) simplifyTopFunction lamExpr ixMethodType :: IxMethod -> AbsDict n -> EnvReaderM n (PiType CoreIR n) @@ -619,52 +637,19 @@ ixMethodType method absDict = do let allBs = extraArgBs >>> methodArgs return $ PiType allBs (EffTy Pure resultTy) --- TODO: do we even need this, or is it just a glorified `SubstM`? simplifyAtom :: CAtom i -> SimplifyM i o (CAtom o) -simplifyAtom atom = confuseGHC >>= \_ -> case atom of - Var v -> simplifyVar v - Lam _ -> substM atom - DepPair x y ty -> DepPair <$> simplifyAtom x <*> simplifyAtom y <*> substM ty - Con con -> Con <$> traverseOp con substM simplifyAtom (error "unexpected lambda") - Eff eff -> Eff <$> substM eff - PtrVar t v -> PtrVar t <$> substM v - DictCon t d -> (DictCon <$> substM t <*> substM d) >>= cheapNormalize - DictHole _ _ _ -> error "shouldn't have dict holes past inference" - NewtypeCon _ _ -> substM atom - ProjectElt _ i x -> normalizeProj i =<< simplifyAtom x - SimpInCore _ -> substM atom - TypeAsAtom _ -> substM atom - -simplifyVar :: AtomVar CoreIR i -> SimplifyM i o (CAtom o) -simplifyVar v = do - env <- getSubst - case env ! atomVarName v of - SubstVal x -> return x - Rename v' -> do - AtomNameBinding bindingInfo <- lookupEnv v' - let ty = getType bindingInfo - case bindingInfo of - -- Functions get inlined only at application sites - LetBound (DeclBinding _ _) | isFun -> return $ Var $ AtomVar v' ty - where isFun = case ty of Pi _ -> True; _ -> False - LetBound (DeclBinding _ (Atom x)) -> dropSubst $ simplifyAtom x - _ -> return $ Var $ AtomVar v' ty +simplifyAtom = substM -- Assumes first order (args/results are "data", allowing newtypes), monormophic simplifyLam :: LamExpr CoreIR i -> SimplifyM i o (LamExpr SimpIR o, Abs (Nest (AtomNameBinder SimpIR)) ReconstructAtom o) simplifyLam (LamExpr bsTop body) = case bsTop of - Nest (b:>ty) bs -> do - ty' <- substM ty - tySimp <- getRepType ty' - withFreshBinder (getNameHint b) tySimp \b''@(b':>_) -> do - x <- liftSimpAtom (sink ty') (Var $ binderVar b'') - extendSubst (b@>SubstVal x) do - (LamExpr bs' body', Abs bsRecon recon) <- simplifyLam $ LamExpr bs body - return (LamExpr (Nest (b':>tySimp) bs') body', Abs (Nest b' bsRecon) recon) + Nest b bs -> withSimplifiedBinder b \b'@(b'':>_) -> do + (LamExpr bs' body', Abs bsRecon recon) <- simplifyLam $ LamExpr bs body + return (LamExpr (Nest b' bs') body', Abs (Nest b'' bsRecon) recon) Empty -> do - SimplifiedBlock body' recon <- buildSimplifiedBlock $ simplifyBlock body + SimplifiedBlock body' recon <- buildSimplifiedBlock $ simplifyExpr body return (LamExpr Empty body', Abs Empty recon) data SplitDataNonData n = SplitDataNonData @@ -676,25 +661,25 @@ data SplitDataNonData n = SplitDataNonData -- bijection between that type and a (data, non-data) pair type. splitDataComponents :: Type CoreIR n -> SimplifyM i n (SplitDataNonData n) splitDataComponents = \case - ProdTy tys -> do + TyCon (ProdType tys) -> do splits <- mapM splitDataComponents tys return $ SplitDataNonData - { dataTy = ProdTy $ map dataTy splits - , nonDataTy = ProdTy $ map nonDataTy splits + { dataTy = TyCon $ ProdType $ map dataTy splits + , nonDataTy = TyCon $ ProdType $ map nonDataTy splits , toSplit = \xProd -> do - xs <- getUnpacked xProd + xs <- getUnpackedReduced xProd (ys, zs) <- unzip <$> forM (zip xs splits) \(x, split) -> toSplit split x - return (ProdVal ys, ProdVal zs) + return (Con $ ProdCon ys, Con $ ProdCon zs) , fromSplit = \xsProd ysProd -> do - xs <- getUnpacked xsProd - ys <- getUnpacked ysProd + xs <- getUnpackedReduced xsProd + ys <- getUnpackedReduced ysProd zs <- forM (zip (zip xs ys) splits) \((x, y), split) -> fromSplit split x y - return $ ProdVal zs } + return $ Con $ ProdCon zs } ty -> tryGetRepType ty >>= \case Just repTy -> return $ SplitDataNonData { dataTy = repTy , nonDataTy = UnitTy - , toSplit = \x -> (,UnitVal) <$> toDataAtomIgnoreReconAssumeNoDecls x + , toSplit = \x -> (,UnitVal) <$> (dropSubst $ toDataAtomAssumeNoDecls x) , fromSplit = \x _ -> liftSimpAtom (sink ty) x } Nothing -> return $ SplitDataNonData { dataTy = UnitTy @@ -715,47 +700,37 @@ buildSimplifiedBlock cont = do return $ RightE (dataResult `PairE` ansTy) case eitherResult of LeftE ans -> do - (block, recon) <- refreshAbs (Abs decls ans) \decls' ans' -> do + (blockAbs, recon) <- refreshAbs (Abs decls ans) \decls' ans' -> do (newResult, reconAbs) <- telescopicCapture (toScopeFrag decls') ans' return (Abs decls' newResult, LamRecon reconAbs) - return $ SimplifiedBlock block recon + block' <- mkBlock blockAbs + return $ SimplifiedBlock block' recon RightE (ans `PairE` ty) -> do let ty' = ignoreHoistFailure $ hoist (toScopeFrag decls) ty - return $ SimplifiedBlock (Abs decls ans) (CoerceRecon ty') + block <- mkBlock $ Abs decls ans + return $ SimplifiedBlock block (CoerceRecon ty') -simplifyOp :: Emits o => NameHint -> PrimOp CoreIR i -> SimplifyM i o (CAtom o) -simplifyOp hint op = case op of +simplifyOp :: Emits o => PrimOp CoreIR i -> SimplifyM i o (CAtom o) +simplifyOp op = case op of Hof (TypedHof (EffTy _ ty) hof) -> do ty' <- substM ty - simplifyHof hint ty' hof + simplifyHof ty' hof MemOp op' -> simplifyGenericOp op' VectorOp op' -> simplifyGenericOp op' RefOp ref eff -> do - ref' <- simplifyDataAtom ref + ref' <- toDataAtom ref liftResult =<< simplifyRefOp eff ref' - BinOp binop x' y' -> do - x <- simplifyDataAtom x' - y <- simplifyDataAtom y' - liftResult =<< case binop of - ISub -> isub x y - IAdd -> iadd x y - IMul -> imul x y - IDiv -> idiv x y - ICmp Less -> ilt x y - ICmp Equal -> ieq x y - _ -> emitOp $ BinOp binop x y - UnOp unOp x' -> do - x <- simplifyDataAtom x' - liftResult =<< emitOp (UnOp unOp x) + BinOp binop x y -> do + x' <- toDataAtom x + y' <- toDataAtom y + liftResult =<< emit (BinOp binop x' y') + UnOp unOp x -> do + x' <- toDataAtom x + liftResult =<< emit (UnOp unOp x') MiscOp op' -> case op' of - Select c' x' y' -> do - c <- simplifyDataAtom c' - x <- simplifyDataAtom x' - y <- simplifyDataAtom y' - liftResult =<< select c x y - ShowAny x' -> do - x <- simplifyAtom x' - dropSubst $ showAny x >>= simplifyBlock + ShowAny x -> do + x' <- simplifyAtom x + dropSubst $ showAny x' >>= simplifyExpr _ -> simplifyGenericOp op' where liftResult x = do @@ -763,66 +738,61 @@ simplifyOp hint op = case op of liftSimpAtom ty x simplifyGenericOp - :: (GenericOp op, IsPrimOp op, HasType CoreIR (op CoreIR), Emits o, + :: (GenericOp op, ToExpr (op SimpIR) SimpIR, HasType CoreIR (op CoreIR), Emits o, OpConst op CoreIR ~ OpConst op SimpIR) => op CoreIR i -> SimplifyM i o (CAtom o) simplifyGenericOp op = do ty <- substM $ getType op - op' <- traverseOp op - (substM >=> getRepType) - (simplifyAtom >=> toDataAtomIgnoreRecon) - (error "shouldn't have lambda left") - result <- liftEnvReaderM (peepholeOp $ toPrimOp op') >>= emitExprToAtom - liftSimpAtom ty result + op' <- traverseOp op getRepType toDataAtom (error "shouldn't have lambda left") + liftSimpAtom ty =<< emit op' {-# INLINE simplifyGenericOp #-} pattern CoerceReconAbs :: Abs (Nest b) ReconstructAtom n pattern CoerceReconAbs <- Abs _ (CoerceRecon _) -applyDictMethod :: Emits o => CType o -> CAtom o -> Int -> [CAtom o] -> SimplifyM i o (CAtom o) -applyDictMethod resultTy d i methodArgs = do - cheapNormalize d >>= \case - DictCon _ (InstanceDict instanceName instanceArgs) -> dropSubst do - instanceArgs' <- mapM simplifyAtom instanceArgs - instanceDef <- lookupInstanceDef instanceName - withInstantiated instanceDef instanceArgs' \(PairE _ body) -> do - let InstanceBody _ methods = body - let method = methods !! i - simplifyApp noHint resultTy method methodArgs - DictCon _ (IxFin n) -> applyIxFinMethod (toEnum i) n methodArgs - d' -> error $ "Not a simplified dict: " ++ pprint d' +applyDictMethod :: Emits o => CType o -> CDict o -> Int -> [CAtom o] -> SimplifyM i o (CAtom o) +applyDictMethod resultTy d i methodArgs = case d of + DictCon (InstanceDict _ instanceName instanceArgs) -> dropSubst do + instanceArgs' <- mapM simplifyAtom instanceArgs + instanceDef <- lookupInstanceDef instanceName + withInstantiated instanceDef instanceArgs' \(PairE _ body) -> do + let InstanceBody _ methods = body + let method = methods !! i + method' <- forceConstructor method + simplifyApp resultTy method' methodArgs + DictCon (IxFin n) -> applyIxFinMethod (toEnum i) n methodArgs + d' -> error $ "Not a simplified dict: " ++ pprint d' where applyIxFinMethod :: EnvReader m => IxMethod -> CAtom n -> [CAtom n] -> m n (CAtom n) applyIxFinMethod method n args = do case (method, args) of (Size, []) -> return n -- result : Nat - (Ordinal, [ix]) -> unwrapNewtype ix -- result : Nat - (UnsafeFromOrdinal, [ix]) -> return $ NewtypeCon (FinCon n) ix + (Ordinal, [ix]) -> reduceUnwrap ix -- result : Nat + (UnsafeFromOrdinal, [ix]) -> return $ toAtom $ NewtypeCon (FinCon n) ix _ -> error "bad ix args" -simplifyHof :: Emits o => NameHint -> CType o -> Hof CoreIR i -> SimplifyM i o (CAtom o) -simplifyHof _hint resultTy = \case - For d ixTypeCore' lam -> do +simplifyHof :: Emits o => CType o -> Hof CoreIR i -> SimplifyM i o (CAtom o) +simplifyHof resultTy = \case + For d (IxType ixTy ixDict) lam -> do (lam', Abs (UnaryNest bIx) recon) <- simplifyLam lam - ixTypeCore <- substM ixTypeCore' - ixTypeSimp <- simplifyIxType ixTypeCore - ans <- emitHof $ For d ixTypeSimp lam' + ixTy' <- getRepType ixTy + ixDict' <- simplifyIxDict ixDict + ans <- emitHof $ For d (IxType ixTy' ixDict') lam' case recon of CoerceRecon _ -> liftSimpAtom resultTy ans LamRecon (Abs bsClosure reconResult) -> do - TabPi resultTabTy <- return resultTy - liftM (SimpInCore . TabLam resultTabTy) $ - PairE ixTypeSimp <$> buildAbs noHint (ixTypeType ixTypeSimp) \i -> buildScoped do - i' <- sinkM i - xs <- unpackTelescope bsClosure =<< tabApp (sink ans) (Var i') - applySubst (bIx@>Rename (atomVarName i') <.> bsClosure @@> map SubstVal xs) reconResult + ab <- buildAbs noHint ixTy' \i -> do + xs <- unpackTelescope bsClosure =<< reduceTabApp (sink ans) (toAtom i) + applySubst (bIx@>Rename (atomVarName i) <.> bsClosure @@> map SubstVal xs) reconResult + TyCon (TabPi resultTy') <- return resultTy + mkStuck $ TabLam $ resultTy' `PairE` ab While body -> do - SimplifiedBlock body' (CoerceRecon _) <- buildSimplifiedBlock $ simplifyBlock body + SimplifiedBlock body' (CoerceRecon _) <- buildSimplifiedBlock $ simplifyExpr body result <- emitHof $ While body' liftSimpAtom resultTy result RunReader r lam -> do - r' <- simplifyDataAtom r + r' <- toDataAtom r (lam', Abs b recon) <- simplifyLam lam ans <- emitHof $ RunReader r' lam' let recon' = ignoreHoistFailure $ hoist b recon @@ -830,7 +800,7 @@ simplifyHof _hint resultTy = \case RunWriter Nothing (BaseMonoid e combine) lam -> do LamExpr (BinaryNest h (_:>RefTy _ wTy)) _ <- return lam wTy' <- substM $ ignoreHoistFailure $ hoist h wTy - e' <- simplifyDataAtom e + e' <- toDataAtom e (combine', CoerceReconAbs) <- simplifyLam combine (lam', Abs b recon) <- simplifyLam lam (ans, w) <- fromPair =<< emitHof (RunWriter Nothing (BaseMonoid e' combine') lam') @@ -840,7 +810,8 @@ simplifyHof _hint resultTy = \case return $ PairVal ans' w' RunWriter _ _ _ -> error "Shouldn't see a RunWriter with a dest in Simplify" RunState Nothing s lam -> do - (s', sTy) <- toDataAtom =<< simplifyAtom s + s' <- toDataAtom s + sTy <- substM $ getType s (lam', Abs b recon) <- simplifyLam lam resultPair <- emitHof $ RunState Nothing s' lam' (ans, sOut) <- fromPair resultPair @@ -850,15 +821,15 @@ simplifyHof _hint resultTy = \case return $ PairVal ans' sOut' RunState _ _ _ -> error "Shouldn't see a RunState with a dest in Simplify" RunIO body -> do - SimplifiedBlock body' recon <- buildSimplifiedBlock $ simplifyBlock body + SimplifiedBlock body' recon <- buildSimplifiedBlock $ simplifyExpr body ans <- emitHof $ RunIO body' applyRecon recon ans RunInit body -> do - SimplifiedBlock body' recon <- buildSimplifiedBlock $ simplifyBlock body + SimplifiedBlock body' recon <- buildSimplifiedBlock $ simplifyExpr body ans <- emitHof $ RunInit body' applyRecon recon ans Linearize lam x -> do - x' <- simplifyDataAtom x + x' <- toDataAtom x -- XXX: we're ignoring the result type here, which only makes sense if we're -- dealing with functions on simple types. (lam', recon) <- simplifyLam lam @@ -870,15 +841,14 @@ simplifyHof _hint resultTy = \case return $ PairVal result' linFun' Transpose lam x -> do (lam', CoerceReconAbs) <- simplifyLam lam - x' <- simplifyDataAtom x + x' <- toDataAtom x result <- transpose lam' x' liftSimpAtom resultTy result CatchException _ body-> do - SimplifiedBlock body' recon <- buildSimplifiedBlock $ simplifyBlock body - simplifiedResultTy <- blockTy body' + SimplifiedBlock body' recon <- buildSimplifiedBlock $ simplifyExpr body block <- liftBuilder $ runSubstReaderT idSubst $ buildBlock $ - exceptToMaybeBlock (sink simplifiedResultTy) body' - result <- emitBlock block + exceptToMaybeExpr body' + result <- emit block case recon of CoerceRecon ty -> do maybeTy <- makePreludeMaybeTy ty @@ -887,19 +857,18 @@ simplifyHof _hint resultTy = \case -- takes an internal SimpIR Maybe to a CoreIR "prelude Maybe" fmapMaybe - :: (EnvReader m, EnvExtender m) - => SAtom n -> (forall l. DExt n l => SAtom l -> m l (CAtom l)) - -> m n (CAtom n) + :: SAtom n -> (forall l. DExt n l => SAtom l -> SimplifyM i l (CAtom l)) + -> SimplifyM i n (CAtom n) fmapMaybe scrut f = do ~(MaybeTy justTy) <- return $ getType scrut (justAlt, resultJustTy) <- withFreshBinder noHint justTy \b -> do - result <- f (Var $ binderVar b) + result <- f (toAtom $ binderVar b) resultTy <- return $ ignoreHoistFailure $ hoist b (getType result) result' <- preludeJustVal result return (Abs b result', resultTy) nothingAlt <- buildAbs noHint UnitTy \_ -> preludeNothingVal $ sink resultJustTy resultMaybeTy <- makePreludeMaybeTy resultJustTy - return $ SimpInCore $ ACase scrut [nothingAlt, justAlt] resultMaybeTy + reduceACase scrut [nothingAlt, justAlt] resultMaybeTy -- This is wrong! The correct implementation is below. And yet there's some -- compensatory bug somewhere that means that the wrong answer works and the @@ -913,17 +882,18 @@ preludeJustVal x = return x preludeNothingVal :: EnvReader m => CType n -> m n (CAtom n) preludeNothingVal ty = do con <- preludeMaybeNewtypeCon ty - return $ NewtypeCon con (NothingAtom ty) + return $ Con $ NewtypeCon con (NothingAtom ty) preludeMaybeNewtypeCon :: EnvReader m => CType n -> m n (NewtypeCon n) preludeMaybeNewtypeCon ty = do ~(Just (UTyConVar tyConName)) <- lookupSourceMap "Maybe" TyConDef sn _ _ _ <- lookupTyCon tyConName - let params = TyConParams [Explicit] [Type ty] + let params = TyConParams [Explicit] [toAtom ty] return $ UserADTData sn tyConName params -simplifyBlock :: Emits o => Block CoreIR i -> SimplifyM i o (CAtom o) -simplifyBlock (Abs decls result) = simplifyDecls decls $ simplifyAtom result +liftSimpFun :: EnvReader m => Type CoreIR n -> LamExpr SimpIR n -> m n (CAtom n) +liftSimpFun (TyCon (Pi piTy)) f = mkStuck $ LiftSimpFun piTy f +liftSimpFun _ _ = error "not a pi type" -- === simplifying custom linearizations === @@ -974,18 +944,20 @@ simplifyCustomLinearization (Abs runtimeBs staticArgs) actives rule = do Abs runtimeBs' <$> buildScoped do ListE staticArgs' <- instantiate (sink $ Abs runtimeBs staticArgs) (sink <$> runtimeArgs) fCustom' <- sinkM fCustom + -- TODO: give a HasType instance to ConcreteCAtom resultTy <- typeOfApp (getType fCustom') staticArgs' - pairResult <- dropSubst $ simplifyApp noHint resultTy fCustom' staticArgs' - (primalResult, fLin) <- fromPair pairResult - primalResult' <- toDataAtomIgnoreRecon primalResult + fCustom'' <- dropSubst $ forceConstructor fCustom' + pairResult <- dropSubst $ simplifyApp resultTy fCustom'' staticArgs' + (primalResult, fLin) <- fromPairReduced pairResult + primalResult' <- dropSubst $ toDataAtom primalResult let explicitPrimalArgs = drop nImplicit staticArgs' allTangentTys <- forM explicitPrimalArgs \primalArg -> do - tangentType =<< getRepType (getType primalArg) + tangentType =<< dropSubst (getRepType (getType primalArg)) let actives' = drop (length actives - nExplicit) actives activeTangentTys <- catMaybes <$> forM (zip allTangentTys actives') \(t, active) -> return case active of True -> Just t; False -> Nothing - fLin' <- buildUnaryLamExpr "t" (ProdTy activeTangentTys) \activeTangentArg -> do - activeTangentArgs <- getUnpacked $ Var activeTangentArg + fLin' <- buildUnaryLamExpr "t" (toType $ ProdType activeTangentTys) \activeTangentArg -> do + activeTangentArgs <- getUnpacked $ toAtom activeTangentArg ListE allTangentTys' <- sinkM $ ListE allTangentTys tangentArgs <- buildTangentArgs zeros (zip allTangentTys' actives') activeTangentArgs -- TODO: we're throwing away core type information here. Once we @@ -994,12 +966,13 @@ simplifyCustomLinearization (Abs runtimeBs staticArgs) actives rule = do -- a custom linearization defined for a function on ADTs will -- not work. fLin' <- sinkM fLin - Pi (CorePiType _ _ bs _) <- return $ getType fLin' + TyCon (Pi (CorePiType _ _ bs _)) <- return $ getType fLin' let tangentCoreTys = fromNonDepNest bs tangentArgs' <- zipWithM liftSimpAtom tangentCoreTys tangentArgs resultTyTangent <- typeOfApp (getType fLin') tangentArgs' - tangentResult <- dropSubst $ simplifyApp noHint resultTyTangent fLin' tangentArgs' - toDataAtomIgnoreRecon tangentResult + fLin'' <- dropSubst $ forceConstructor fLin' + tangentResult <- dropSubst $ simplifyApp resultTyTangent fLin'' tangentArgs' + dropSubst $ toDataAtom tangentResult return $ PairE primalResult' fLin' PairE primalFun tangentFun <- defuncLinearized linearized primalFun' <- asTopLam primalFun @@ -1040,10 +1013,10 @@ defuncLinearized ab = liftBuilder $ refreshAbs ab \bs ab' -> do return $ Abs (Nest rB tBs') UnitE residualsTangentsBs' <- return $ ignoreHoistFailure $ hoist decls residualsTangentsBs return (Abs decls (PairVal primalResult residuals), reconAbs, residualsTangentsBs') - let primalFun = LamExpr bs declsAndResult + primalFun <- LamExpr bs <$> mkBlock declsAndResult LamExpr residualAndTangentBs tangentBody <- buildLamExpr residualsTangentsBs \(residuals:tangents) -> do - LamExpr tangentBs' body <- applyReconAbs (sink reconAbs) (Var residuals) - applyRename (tangentBs' @@> (atomVarName <$> tangents)) body >>= emitBlock + LamExpr tangentBs' body <- applyReconAbs (sink reconAbs) (toAtom residuals) + applyRename (tangentBs' @@> (atomVarName <$> tangents)) body >>= emit let tangentFun = LamExpr (bs >>> residualAndTangentBs) tangentBody return $ PairE primalFun tangentFun @@ -1053,7 +1026,7 @@ type HandlerM = SubstReaderT AtomSubstVal (BuilderM SimpIR) exceptToMaybeBlock :: Emits o => SType o -> SBlock i -> HandlerM i o (SAtom o) exceptToMaybeBlock ty (Abs Empty result) = do - result' <- substM result + result' <- exceptToMaybeExpr result return $ JustAtom ty result' exceptToMaybeBlock resultTy (Abs (Nest (Let b (DeclBinding _ rhs)) decls) finalResult) = do maybeResult <- exceptToMaybeExpr rhs @@ -1068,24 +1041,24 @@ exceptToMaybeBlock resultTy (Abs (Nest (Let b (DeclBinding _ rhs)) decls) finalR exceptToMaybeExpr :: Emits o => SExpr i -> HandlerM i o (SAtom o) exceptToMaybeExpr expr = case expr of + Block (EffTy _ ty) body -> do + ty' <- substM ty + exceptToMaybeBlock ty' body Case e alts (EffTy _ resultTy) -> do e' <- substM e resultTy' <- substM $ MaybeTy resultTy buildCase e' resultTy' \i v -> do Abs b body <- return $ alts !! i extendSubst (b @> SubstVal v) do - blockResultTy <- blockTy =<< substM body -- TODO: avoid this by caching the type - exceptToMaybeBlock blockResultTy body + exceptToMaybeExpr body Atom x -> do x' <- substM x let ty = getType x' return $ JustAtom ty x' PrimOp (Hof (TypedHof _ (For ann ixTy' (UnaryLamExpr b body)))) -> do ixTy <- substM ixTy' - maybes <- buildForAnn (getNameHint b) ann ixTy \i -> do - extendSubst (b@>Rename (atomVarName i)) do - blockResultTy <- blockTy =<< substM body -- TODO: avoid this by caching the type - exceptToMaybeBlock blockResultTy body + maybes <- buildFor (getNameHint b) ann ixTy \i -> do + extendSubst (b@>Rename (atomVarName i)) $ exceptToMaybeExpr body catMaybesE maybes PrimOp (MiscOp (ThrowException _)) -> do ty <- substM $ getType expr @@ -1095,8 +1068,7 @@ exceptToMaybeExpr expr = case expr of BinaryLamExpr h ref body <- return lam result <- emitRunState noHint s' \h' ref' -> extendSubst (h @> Rename (atomVarName h') <.> ref @> Rename (atomVarName ref')) do - blockResultTy <- blockTy =<< substM body -- TODO: avoid this by caching the type - exceptToMaybeBlock blockResultTy body + exceptToMaybeExpr body (maybeAns, newState) <- fromPair result a <- substM $ getType expr emitMaybeCase maybeAns (MaybeTy a) @@ -1107,16 +1079,13 @@ exceptToMaybeExpr expr = case expr of PairTy _ accumTy <- substM resultTy result <- emitRunWriter noHint accumTy monoid' \h' ref' -> extendSubst (h @> Rename (atomVarName h') <.> ref @> Rename (atomVarName ref')) do - blockResultTy <- blockTy =<< substM body -- TODO: avoid this by caching the type - exceptToMaybeBlock blockResultTy body + exceptToMaybeExpr body (maybeAns, accumResult) <- fromPair result a <- substM $ getType expr emitMaybeCase maybeAns (MaybeTy a) (return $ NothingAtom $ sink a) (\ans -> return $ JustAtom (sink a) $ PairVal ans (sink accumResult)) - PrimOp (Hof (TypedHof _ (While body))) -> do - blockResultTy <- blockTy =<< substM body -- TODO: avoid this by caching the type - runMaybeWhile $ exceptToMaybeBlock (sink blockResultTy) body + PrimOp (Hof (TypedHof _ (While body))) -> runMaybeWhile $ exceptToMaybeExpr body _ -> do expr' <- substM expr case hasExceptions expr' of @@ -1124,7 +1093,7 @@ exceptToMaybeExpr expr = case expr of False -> do v <- emit expr' let ty = getType v - return $ JustAtom ty (Var v) + return $ JustAtom ty v hasExceptions :: SExpr n -> Bool hasExceptions expr = case getEffects expr of diff --git a/src/lib/Simplify.hs-boot b/src/lib/Simplify.hs-boot index c14ae648a..8e1499c3d 100644 --- a/src/lib/Simplify.hs-boot +++ b/src/lib/Simplify.hs-boot @@ -9,5 +9,6 @@ module Simplify (linearizeTopFun) where import Name import Builder import Types.Core +import Types.Top linearizeTopFun :: (Mut n, Fallible1 m, TopBuilder m) => LinearizationSpec n -> m n (TopFunName n, TopFunName n) diff --git a/src/lib/SourceIdTraversal.hs b/src/lib/SourceIdTraversal.hs new file mode 100644 index 000000000..7e2436200 --- /dev/null +++ b/src/lib/SourceIdTraversal.hs @@ -0,0 +1,117 @@ +-- Copyright 2023 Google LLC +-- +-- Use of this source code is governed by a BSD-style +-- license that can be found in the LICENSE file or at +-- https://developers.google.com/open-source/licenses/bsd + +module SourceIdTraversal (getGroupTree) where + +import Control.Monad.Writer.Strict +import Data.Functor ((<&>)) + +import Types.Source +import Types.Primitives +import Err + +getGroupTree :: SourceBlock' -> GroupTree +getGroupTree b = mkGroupTree False rootSrcId $ runTreeM $ visit b + +type TreeM = Writer [GroupTree] + +mkGroupTree :: Bool -> SrcId -> [GroupTree] -> GroupTree +mkGroupTree isAtomic sid = \case + [] -> GroupTree sid (sid,sid) [] isAtomic -- no children - must be a lexeme + subtrees -> GroupTree sid (l,r) subtrees isAtomic + where l = minimum $ subtrees <&> (fst . gtSpan) + r = maximum $ subtrees <&> (snd . gtSpan) + +runTreeM :: TreeM () -> [GroupTree] +runTreeM cont = snd $ runWriter $ cont + +enterNode :: SrcId -> TreeM () -> TreeM () +enterNode sid cont = tell [mkGroupTree False sid (runTreeM cont)] + +emitLexeme :: SrcId -> TreeM () +emitLexeme lexemeId = tell [mkGroupTree True lexemeId []] + +class IsTree a where + visit :: a -> TreeM () + +instance IsTree SourceBlock' where + visit = \case + TopDecl decl -> visit decl + Command _ g -> visit g + DeclareForeign v1 v2 g -> visit v1 >> visit v2 >> visit g + DeclareCustomLinearization v _ g -> visit v >> visit g + Misc _ -> return () + UnParseable _ _ -> return () + +instance IsTree Group where + visit = \case + CLeaf _ -> return () + CPrim _ xs -> mapM_ visit xs + CParens xs -> mapM_ visit xs + CBrackets xs -> mapM_ visit xs + CBin b l r -> visit l >> visit b >> visit r + CJuxtapose _ l r -> visit l >> visit r + CPrefix l r -> visit l >> visit r + CGivens (x,y) -> visit x >> visit y + CLambda args body -> visit args >> visit body + CFor _ args body -> visit args >> visit body + CCase scrut alts -> visit scrut >> visit alts + CIf scrut ifTrue ifFalse -> visit scrut >> visit ifTrue >> visit ifFalse + CDo body -> visit body + CArrow l effs r -> visit l >> visit effs >> visit r + CWith b body -> visit b >> visit body + +instance IsTree CSBlock where + visit = \case + IndentedBlock sid decls -> enterNode sid $ visit decls + ExprBlock body -> visit body + +instance IsTree CSDecl where + visit = \case + CLet v rhs -> visit v >> visit rhs + CDefDecl def -> visit def + CExpr g -> visit g + CBind v body -> visit v >> visit body + CPass -> return () + +instance IsTree CTopDecl where + visit = \case + CSDecl _ decl -> visit decl + CData v params givens cons -> visit v >> visit params >> visit givens >> visit cons + CStruct v params givens fields methods -> visit v >> visit params >> visit givens >> visit fields >> visit methods + CInterface v params methods -> visit v >> visit params >> visit methods + CInstanceDecl def -> visit def + +instance IsTree CDef where + visit (CDef v params rhs givens body) = + visit v >> visit params >> visit rhs >> visit givens >> visit body + +instance IsTree CInstanceDef where + visit (CInstanceDef v args givens methods name) = + visit v >> visit args >> visit givens >> visit methods >> visit name + +instance IsTree a => IsTree (WithSrc a) where + visit (WithSrc sid x) = enterNode sid $ visit x + +instance IsTree a => IsTree (WithSrcs a) where + visit (WithSrcs sid sids x) = enterNode sid $ mapM_ emitLexeme sids >> visit x + +instance IsTree a => IsTree [a] where + visit xs = mapM_ visit xs + +instance IsTree a => IsTree (Maybe a) where + visit xs = mapM_ visit xs + +instance (IsTree a, IsTree b) => IsTree (a, b) where + visit (x, y) = visit x >> visit y + +instance (IsTree a, IsTree b, IsTree c) => IsTree (a, b, c) where + visit (x, y, z) = visit x >> visit y >> visit z + +instance IsTree AppExplicitness where visit _ = return () +instance IsTree SourceName where visit _ = return () +instance IsTree LetAnn where visit _ = return () +instance IsTree Bin where visit _ = return () diff --git a/src/lib/SourceInfo.hs b/src/lib/SourceInfo.hs deleted file mode 100644 index 6079fe5ff..000000000 --- a/src/lib/SourceInfo.hs +++ /dev/null @@ -1,226 +0,0 @@ --- Copyright 2021 Google LLC --- --- Use of this source code is governed by a BSD-style --- license that can be found in the LICENSE file or at --- https://developers.google.com/open-source/licenses/bsd - -{-# OPTIONS_GHC -Wno-incomplete-patterns #-} - -module SourceInfo ( - SrcPos, SpanId, SrcPosCtx (..), emptySrcPosCtx, fromPos, - pattern EmptySrcPosCtx, - sliceText, SpanTree (..), SpanTreeM (..), SpanPayload, SpanPos, - evalSpanTree, makeSpanTree, makeEmptySpanTree, makeSpanTreeRec, - fixSpanPayloads, - fillTreeAndAddTrivialLeaves - ) where - -import Data.Data -import Data.Hashable -import Data.Char (isSpace) -import Data.List (findIndex) -import Data.Maybe (listToMaybe, maybeToList) -import Data.Store (Store (..)) -import qualified Data.Text as T -import GHC.Generics (Generic (..)) -import Control.Applicative -import Control.Monad.State.Strict - --- === Core API === - -type SrcPos = (Int, Int) -type SpanId = Int - -data SrcPosCtx = SrcPosCtx (Maybe SrcPos) (Maybe SpanId) - deriving (Show, Eq, Generic, Data) -instance Hashable SrcPosCtx -instance Store SrcPosCtx - -instance Ord SrcPosCtx where - compare (SrcPosCtx pos spanId) (SrcPosCtx pos' spanId') = - case (pos, pos') of - (Just (l, r), Just (l', r')) -> compare (l, r', spanId) (l', r, spanId') - (Just _, _) -> GT - (_, Just _) -> LT - (_, _) -> compare spanId spanId' - -emptySrcPosCtx :: SrcPosCtx -emptySrcPosCtx = SrcPosCtx Nothing Nothing - -pattern EmptySrcPosCtx :: SrcPosCtx -pattern EmptySrcPosCtx = SrcPosCtx Nothing Nothing - -fromPos :: SrcPos -> SrcPosCtx -fromPos pos = SrcPosCtx (Just pos) Nothing - --- === Span utilities === - -type SpanPayload = (Int, Int, SpanId) -type SpanPos = (Int, Int) - -data SpanTree = - Span SpanPayload [SpanTree] | - LeafSpan SpanPayload | - Trivia SpanPos - deriving (Show, Eq) - -newtype SpanTreeM a = SpanTreeM - { runSpanTree' :: StateT [SpanPayload] Maybe a } - deriving (Functor, Applicative, Monad, MonadState [SpanPayload], Alternative) - -evalSpanTree :: SpanTreeM a -> [SpanPayload] -> Maybe a -evalSpanTree m spans = evalStateT (runSpanTree' m) spans - -getNextSpanPayload :: SpanTreeM (Maybe SpanPayload) -getNextSpanPayload = SpanTreeM $ do - infos <- get - case infos of - [] -> return Nothing - x:xs -> put xs >> return (Just x) - -data SpanContained = Contained | NotContained | PartialOverlap - deriving (Show, Eq) - --- | @contained x y@ returns whether @y@ is contained in @x@. -spanContained :: SpanPayload -> SpanPayload -> SpanContained -spanContained (lpos, rpos, _) (lpos', rpos', _) = - case (lpos <= lpos', rpos >= rpos') of - (True, True) -> Contained - (False, False) -> NotContained - (_, _) -> if rpos <= lpos' - then NotContained - else PartialOverlap - --- | @makeSpanTreeRec x@ returns a @[SpanTree]@ with the children of @x@. -getSpanChildren :: SpanPayload -> SpanTreeM (Maybe [SpanTree]) -getSpanChildren root = do - getNextSpanPayload >>= \case - Just child -> do - case spanContained root child of - -- If `child` is contained in `root`, then we add it as a child. - Contained -> do - childTree <- makeSpanTreeRec child - remainingChildren <- getSpanChildren root - return $ Just (maybeToList childTree ++ concat (maybeToList remainingChildren)) - NotContained -> do infos <- get; put (child : infos); return $ Just [] - PartialOverlap -> do infos <- get; put (child : infos); return $ Just [] - Nothing -> return $ Just [] - --- | @makeSpanTreeRec x@ returns a @SpanTree@ with the @x@ as the root. -makeSpanTreeRec :: SpanPayload -> SpanTreeM (Maybe SpanTree) -makeSpanTreeRec root = do - children <- getSpanChildren root - case children of - Nothing -> return Nothing - Just [] -> return $ Just (LeafSpan root) - Just xs -> return $ Just (Span root xs) - -makeEmptySpanTree :: [SpanPayload] -> Maybe SpanTree -makeEmptySpanTree [] = Nothing -makeEmptySpanTree (root:children) = join $ evalSpanTree (makeSpanTreeRec root) children - -makeSpanTree :: (Show a, IsTrivia a) => [a] -> [SpanPayload] -> Maybe SpanTree -makeSpanTree xs infos = case makeEmptySpanTree infos of - Nothing -> Nothing - Just posTree -> Just (fillTreeAndAddTrivialLeaves xs posTree) - -slice :: Int -> Int -> [a] -> [a] -slice left right xs = take (right - left) (drop left xs) - -sliceText :: Int -> Int -> T.Text -> T.Text -sliceText left right xs = T.take (right - left) (T.drop left xs) - -getSpanPos :: SpanTree -> SpanPos -getSpanPos tree = case tree of - Span (l, r, _) _ -> (l, r) - LeafSpan (l, r, _) -> (l, r) - Trivia pos -> pos - -fillTrivia :: SpanPayload -> [SpanTree] -> [SpanTree] -fillTrivia (l, r, _) offsets = - let (before, after) = case offsets of - [] -> ([], []) - _ -> - let (headL, _) = getSpanPos (head offsets) in - let (_, tailR) = getSpanPos (last offsets) in - let before' = [Trivia (l, headL) | l /= headL] in - let after' = [Trivia (tailR, r) | r /= tailR] in - (before', after') in - let offsets' = before ++ offsets ++ after in - let pairs = zip offsets' (drop 1 offsets') in - let unzipped = pairs >>= getOffsetAndTrivia in - maybeToList (listToMaybe offsets') ++ unzipped - where getOffsetAndTrivia :: (SpanTree, SpanTree) -> [SpanTree] - getOffsetAndTrivia (t, t') = - let (_, r') = endpoints t in - let (l', _) = endpoints t' in - let diff = l' - r' in - if diff == 0 then - [t'] - else - [Trivia (r', l'), t'] - -fixSpanPayloads :: [SpanPayload] -> [SpanPayload] -fixSpanPayloads spans = - let pairs = zip spans (drop 1 spans) in - let unzipped = pairs >>= mergeSpans in - unzipped ++ [last spans] - where mergeSpans :: (SpanPayload, SpanPayload) -> [SpanPayload] - mergeSpans (s, s') = case spanContained s s' of - Contained -> [s] - NotContained -> [s] - -- Note: currently, overlapping spans are simply dropped. - -- Consider replacing with approach that preserves partial span info. - PartialOverlap -> [] - -rebalanceTrivia :: Show a => (a -> Bool) -> [a] -> [SpanTree] -> [SpanTree] -rebalanceTrivia trivia xs trees = - let whitespaceSeparated = trees >>= createTrivia in - whitespaceSeparated - where - createTrivia :: SpanTree -> [SpanTree] - createTrivia t = case t of - Span _ _ -> [t] - LeafSpan _ -> blah - Trivia _ -> blah - where blah :: [SpanTree] - blah = - let (l, r) = endpoints t in - let s' = slice l r xs in - let firstNonTrivia = findIndex (not . trivia) s' in - let lastNonTrivia = fmap (length s' -) (findIndex (not . trivia) (reverse s')) in - case (firstNonTrivia, lastNonTrivia) of - (Just l', Nothing) | l' > 0 -> [Trivia (l, l + l'), shiftTree (l + l', r) t] - (Nothing, Just r') | r' < length s' -> [shiftTree (l, l + r') t, Trivia (l + r', r)] - (Just l', Just r') | l' > 0 || r' < length s' -> - [Trivia (l, l + l'), shiftTree (l + l', l + r') t, Trivia (l + r', r)] - (_, _) -> [t] - - -- - shiftTree :: SpanPos -> SpanTree -> SpanTree - shiftTree (l', r') t = case t of - Span (_, _, i) children -> Span (l', r', i) children - LeafSpan (_, _, i) -> LeafSpan (l', r', i) - Trivia _ -> Trivia (l', r') - -endpoints :: SpanTree -> (Int, Int) -endpoints (Span (l, r, _) _) = (l, r) -endpoints (LeafSpan (l, r, _)) = (l, r) -endpoints (Trivia (l, r)) = (l, r) - -class IsTrivia a where - isTrivia :: a -> Bool - -instance IsTrivia Char where - isTrivia = isSpace - --- | Fills a @SpanTree@ with @Trivia@ in span gaps. -fillTreeAndAddTrivialLeaves :: Show a => IsTrivia a => [a] -> SpanTree -> SpanTree -fillTreeAndAddTrivialLeaves xs tree = case tree of - Span info children -> - let children' = fillTrivia info children in - let children'' = rebalanceTrivia isTrivia xs children' in - let filled = map (fillTreeAndAddTrivialLeaves xs) children'' in - Span info filled - LeafSpan _ -> tree - Trivia _ -> tree diff --git a/src/lib/SourceRename.hs b/src/lib/SourceRename.hs index 3ee3b13b1..d5420dab7 100644 --- a/src/lib/SourceRename.hs +++ b/src/lib/SourceRename.hs @@ -10,7 +10,6 @@ module SourceRename ( renameSourceNamesTopUDecl, uDeclErrSourceMap , renameSourceNamesUExpr ) where import Prelude hiding (id, (.)) -import Data.List (sort) import Control.Category import Control.Monad.Except hiding (Except) import qualified Data.Set as S @@ -19,11 +18,11 @@ import qualified Data.Map.Strict as M import Err import Name import Core (EnvReader (..), withEnv, lookupSourceMapPure) -import PPrint () +import PPrint import IRVariants import Types.Source import Types.Primitives -import Types.Core (Env (..), ModuleEnv (..)) +import Types.Top (Env (..), ModuleEnv (..)) renameSourceNamesTopUDecl :: (Fallible1 m, EnvReader m) @@ -60,7 +59,7 @@ data RenamerSubst n = RenamerSubst { renamerSourceMap :: SourceMap n , renamerMayShadow :: Bool } newtype RenamerM (n::S) (a:: *) = - RenamerM { runRenamerM :: OutReaderT RenamerSubst (ScopeReaderT FallibleM) n a } + RenamerM { runRenamerM :: OutReaderT RenamerSubst (ScopeReaderT Except) n a } deriving ( Functor, Applicative, Monad, MonadFail, Fallible , ScopeReader, ScopeExtender) @@ -68,7 +67,7 @@ liftRenamer :: (EnvReader m, Fallible1 m, SinkableE e) => RenamerM n (e n) -> m liftRenamer cont = do sm <- withEnv $ envSourceMap . moduleEnv Distinct <- getDistinct - (liftExcept =<<) $ liftM runFallibleM $ liftScopeReaderT $ + (liftExcept =<<) $ liftScopeReaderT $ runOutReaderT (RenamerSubst sm False) $ runRenamerM $ cont class ( Monad1 m, ScopeReader m @@ -99,60 +98,48 @@ class SourceRenamableB (b :: B) where -> m o a instance SourceRenamableE (SourceNameOr UVar) where - sourceRenameE (SourceName pos sourceName) = - InternalName pos sourceName <$> lookupSourceName sourceName + sourceRenameE (SourceName sid sourceName) = + InternalName sid sourceName <$> lookupSourceName sid sourceName sourceRenameE _ = error "Shouldn't be source-renaming internal names" -lookupSourceName :: Renamer m => SourceName -> m n (UVar n) -lookupSourceName v = do +lookupSourceName :: Renamer m => SrcId -> SourceName -> m n (UVar n) +lookupSourceName sid v = do sm <- askSourceMap case lookupSourceMapPure sm v of - [] -> throw UnboundVarErr $ pprint v + [] -> throw sid $ UnboundVarErr $ pprint v LocalVar v' : _ -> return v' [ModuleVar _ maybeV] -> case maybeV of Just v' -> return v' - Nothing -> throw VarDefErr v - vs -> throw AmbiguousVarErr $ ambiguousVarErrMsg v vs - -ambiguousVarErrMsg :: SourceName -> [SourceNameDef n] -> String -ambiguousVarErrMsg v defs = - -- we sort the lines to make the result a bit more deterministic for quine tests - pprint v ++ " is defined:\n" ++ unlines (sort $ map defsPretty defs) - where - defsPretty :: SourceNameDef n -> String - defsPretty (ModuleVar mname _) = case mname of - Main -> "in this file" - Prelude -> "in the prelude" - OrdinaryModule mname' -> "in " ++ pprint mname' - defsPretty (LocalVar _) = - error "shouldn't be possible because module vars can't shadow local ones" + Nothing -> throw sid $ VarDefErr $ pprint v + vs -> throw sid $ AmbiguousVarErr (pprint v) (map wherePretty vs) + where + wherePretty :: SourceNameDef n -> String + wherePretty (ModuleVar mname _) = case mname of + Main -> "in this file" + Prelude -> "in the prelude" + OrdinaryModule mname' -> "in " ++ pprint mname' + wherePretty (LocalVar _) = + error "shouldn't be possible because module vars can't shadow local ones" instance SourceRenamableE (SourceNameOr (Name (AtomNameC CoreIR))) where - sourceRenameE (SourceName pos sourceName) = do - lookupSourceName sourceName >>= \case - UAtomVar v -> return $ InternalName pos sourceName v - _ -> throw TypeErr $ "Not an ordinary variable: " ++ pprint sourceName + sourceRenameE (SourceName sid sourceName) = do + lookupSourceName sid sourceName >>= \case + UAtomVar v -> return $ InternalName sid sourceName v + _ -> throw sid $ NotAnOrdinaryVar $ pprint sourceName sourceRenameE _ = error "Shouldn't be source-renaming internal names" instance SourceRenamableE (SourceNameOr (Name DataConNameC)) where - sourceRenameE (SourceName pos sourceName) = do - lookupSourceName sourceName >>= \case - UDataConVar v -> return $ InternalName pos sourceName v - _ -> throw TypeErr $ "Not a data constructor: " ++ pprint sourceName + sourceRenameE (SourceName sid sourceName) = do + lookupSourceName sid sourceName >>= \case + UDataConVar v -> return $ InternalName sid sourceName v + _ -> throw sid $ NotADataCon $ pprint sourceName sourceRenameE _ = error "Shouldn't be source-renaming internal names" instance SourceRenamableE (SourceNameOr (Name ClassNameC)) where - sourceRenameE (SourceName pos sourceName) = do - lookupSourceName sourceName >>= \case - UClassVar v -> return $ InternalName pos sourceName v - _ -> throw TypeErr $ "Not a class name: " ++ pprint sourceName - sourceRenameE _ = error "Shouldn't be source-renaming internal names" - -instance SourceRenamableE (SourceNameOr (Name EffectNameC)) where - sourceRenameE (SourceName pos sourceName) = do - lookupSourceName sourceName >>= \case - UEffectVar v -> return $ InternalName pos sourceName v - _ -> throw TypeErr $ "Not an effect name: " ++ pprint sourceName + sourceRenameE (SourceName sid sourceName) = do + lookupSourceName sid sourceName >>= \case + UClassVar v -> return $ InternalName sid sourceName v + _ -> throw sid $ NotAClassName $ pprint sourceName sourceRenameE _ = error "Shouldn't be source-renaming internal names" instance SourceRenamableE (SourceNameOr (Name c)) => SourceRenamableE (SourceOrInternalName c) where @@ -164,25 +151,24 @@ instance (SourceRenamableE e, SourceRenamableB b) => SourceRenamableE (Abs b e) instance SourceRenamableB (UBinder (AtomNameC CoreIR)) where sourceRenameB b cont = sourceRenameUBinder UAtomVar b cont -instance SourceRenamableE (UAnn req) where +instance SourceRenamableE UAnn where sourceRenameE UNoAnn = return UNoAnn sourceRenameE (UAnn ann) = UAnn <$> sourceRenameE ann -instance SourceRenamableB (UAnnBinder req) where - sourceRenameB (UAnnBinder b ann cs) cont = do +instance SourceRenamableB UAnnBinder where + sourceRenameB (UAnnBinder expl b ann cs) cont = do ann' <- sourceRenameE ann - cs' <- mapM sourceRenameE cs - sourceRenameB b \b' -> - cont $ UAnnBinder b' ann' cs' + cs' <- mapM sourceRenameE cs + sourceRenameB b \b' -> cont $ UAnnBinder expl b' ann' cs' -instance SourceRenamableE UExpr' where - sourceRenameE expr = setMayShadow True case expr of +instance SourceRenamableE UExpr where + sourceRenameE (WithSrcE sid expr) = liftM (WithSrcE sid) $ setMayShadow True case expr of UVar v -> UVar <$> sourceRenameE v ULit l -> return $ ULit l ULam lam -> ULam <$> sourceRenameE lam - UPi (UPiExpr (attrs, pats) appExpl eff body) -> + UPi (UPiExpr pats appExpl eff body) -> sourceRenameB pats \pats' -> - UPi <$> (UPiExpr (attrs, pats') <$> pure appExpl <*> sourceRenameE eff <*> sourceRenameE body) + UPi <$> (UPiExpr pats' <$> pure appExpl <*> sourceRenameE eff <*> sourceRenameE body) UApp f xs ys -> UApp <$> sourceRenameE f <*> forM xs sourceRenameE <*> forM ys (\(name, y) -> (name,) <$> sourceRenameE y) @@ -225,14 +211,6 @@ instance SourceRenamableE UEffect where sourceRenameE UExceptionEffect = return UExceptionEffect sourceRenameE UIOEffect = return UIOEffect -instance SourceRenamableE a => SourceRenamableE (WithSrcE a) where - sourceRenameE (WithSrcE pos e) = addSrcContext pos $ - WithSrcE pos <$> sourceRenameE e - -instance SourceRenamableB a => SourceRenamableB (WithSrcB a) where - sourceRenameB (WithSrcB pos b) cont = addSrcContext pos $ - sourceRenameB b \b' -> cont $ WithSrcB pos b' - instance SourceRenamableB UTopDecl where sourceRenameB decl cont = case decl of ULocalDecl d -> sourceRenameB d \d' -> cont $ ULocalDecl d' @@ -245,50 +223,46 @@ instance SourceRenamableB UTopDecl where sourceRenameUBinder UPunVar tyConName \tyConName' -> do structDef' <- sourceRenameE structDef cont $ UStructDecl tyConName' structDef' - UInterface (attrs, paramBs) methodTys className methodNames -> do + UInterface paramBs methodTys className methodNames -> do Abs paramBs' (ListE methodTys') <- sourceRenameB paramBs \paramBs' -> do methodTys' <- mapM sourceRenameE methodTys return $ Abs paramBs' $ ListE methodTys' sourceRenameUBinder UClassVar className \className' -> sourceRenameUBinderNest UMethodVar methodNames \methodNames' -> - cont $ UInterface (attrs, paramBs') methodTys' className' methodNames' - UInstance className (roleExpls, conditions) params methodDefs instanceName expl -> do + cont $ UInterface paramBs' methodTys' className' methodNames' + UInstance className conditions params methodDefs instanceName expl -> do className' <- sourceRenameE className Abs conditions' (PairE (ListE params') (ListE methodDefs')) <- sourceRenameE $ Abs conditions (PairE (ListE params) $ ListE methodDefs) sourceRenameB instanceName \instanceName' -> - cont $ UInstance className' (roleExpls, conditions') params' methodDefs' instanceName' expl - UEffectDecl opTypes effName opNames -> do - opTypes' <- mapM (\(UEffectOpType p ty) -> (UEffectOpType p) <$> sourceRenameE ty) opTypes - sourceRenameUBinder UEffectVar effName \effName' -> - sourceRenameUBinderNest UEffectOpVar opNames \opNames' -> - cont $ UEffectDecl opTypes' effName' opNames' - UHandlerDecl _ _ _ _ _ _ _ -> error "not implemented" - -instance SourceRenamableB UDecl' where - sourceRenameB decl cont = case decl of + cont $ UInstance className' conditions' params' methodDefs' instanceName' expl + +instance SourceRenamableB UDecl where + sourceRenameB (WithSrcB sid decl) cont = case decl of ULet ann pat ty expr -> do expr' <- sourceRenameE expr ty' <- mapM sourceRenameE ty sourceRenameB pat \pat' -> - cont $ ULet ann pat' ty' expr' - UExprDecl e -> cont =<< (UExprDecl <$> sourceRenameE e) - UPass -> cont UPass + cont $ WithSrcB sid $ ULet ann pat' ty' expr' + UExprDecl e -> do + e' <- UExprDecl <$> sourceRenameE e + cont $ WithSrcB sid e' + UPass -> cont $ WithSrcB sid UPass instance SourceRenamableE ULamExpr where - sourceRenameE (ULamExpr (expls, args) expl effs resultTy body) = - sourceRenameB args \args' -> ULamExpr (expls, args') + sourceRenameE (ULamExpr args expl effs resultTy body) = + sourceRenameB args \args' -> ULamExpr args' <$> pure expl <*> mapM sourceRenameE effs <*> mapM sourceRenameE resultTy <*> sourceRenameE body -instance SourceRenamableE UBlock' where - sourceRenameE (UBlock decls result) = +instance SourceRenamableE UBlock where + sourceRenameE (WithSrcE sid (UBlock decls result)) = sourceRenameB decls \decls' -> do result' <- sourceRenameE result - return $ UBlock decls' result' + return $ WithSrcE sid $ UBlock decls' result' instance SourceRenamableB UnitB where sourceRenameB UnitB cont = cont UnitB @@ -316,35 +290,35 @@ sourceRenameUBinderNest asUVar (Nest b bs) cont = sourceRenameUBinderNest asUVar bs \bs' -> cont $ Nest b' bs' -sourceRenameUBinder :: (Color c, Distinct o, Renamer m) - => (forall l. Name c l -> UVar l) - -> UBinder c i i' - -> (forall o'. DExt o o' => UBinder c o o' -> m o' a) - -> m o a -sourceRenameUBinder asUVar ubinder cont = case ubinder of - UBindSource pos b -> do +sourceRenameUBinder + :: (Color c, Distinct o, Renamer m) + => (forall l. Name c l -> UVar l) + -> UBinder c i i' + -> (forall o'. DExt o o' => UBinder c o o' -> m o' a) + -> m o a +sourceRenameUBinder asUVar (WithSrcB sid ubinder) cont = case ubinder of + UBindSource b -> do SourceMap sm <- askSourceMap mayShadow <- askMayShadow let shadows = M.member b sm - when (not mayShadow && shadows) $ - throw RepeatedVarErr $ pprint b + when (not mayShadow && shadows) $ throw sid $ RepeatedVarErr $ pprint b withFreshM (getNameHint b) \freshName -> do Distinct <- getDistinct extendSourceMap b (asUVar $ binderName freshName) $ - cont $ UBind pos b freshName - UBind _ _ _ -> error "Shouldn't be source-renaming internal names" - UIgnore -> cont UIgnore + cont $ WithSrcB sid $ UBind b freshName + UBind _ _ -> error "Shouldn't be source-renaming internal names" + UIgnore -> cont $ WithSrcB sid $ UIgnore instance SourceRenamableE UDataDef where - sourceRenameE (UDataDef tyConName (expls, paramBs) dataCons) = do + sourceRenameE (UDataDef tyConName paramBs dataCons) = do sourceRenameB paramBs \paramBs' -> do dataCons' <- forM dataCons \(dataConName, argBs) -> do argBs' <- sourceRenameE argBs return (dataConName, argBs') - return $ UDataDef tyConName (expls, paramBs') dataCons' + return $ UDataDef tyConName paramBs' dataCons' instance SourceRenamableE UStructDef where - sourceRenameE (UStructDef tyConName (expls, paramBs) fields methods) = do + sourceRenameE (UStructDef tyConName paramBs fields methods) = do sourceRenameB paramBs \paramBs' -> do fields' <- forM fields \(fieldName, ty) -> do ty' <- sourceRenameE ty @@ -352,7 +326,7 @@ instance SourceRenamableE UStructDef where methods' <- forM methods \(ann, methodName, lam) -> do lam' <- sourceRenameE lam return (ann, methodName, lam') - return $ UStructDef tyConName (expls, paramBs') fields' methods' + return $ UStructDef tyConName paramBs' fields' methods' instance SourceRenamableE UDataDefTrail where sourceRenameE (UDataDefTrail args) = sourceRenameB args \args' -> @@ -371,19 +345,11 @@ instance SourceRenamableE e => SourceRenamableE (ListE e) where instance SourceRenamableE UnitE where sourceRenameE UnitE = return UnitE -instance SourceRenamableE UMethodDef' where - sourceRenameE (UMethodDef ~(SourceName pos v) expr) = do - lookupSourceName v >>= \case - UMethodVar v' -> UMethodDef (InternalName pos v v') <$> sourceRenameE expr - _ -> throw TypeErr $ "not a method name: " ++ pprint v - -instance SourceRenamableE UEffectOpDef where - sourceRenameE (UReturnOpDef expr) = do - UReturnOpDef <$> sourceRenameE expr - sourceRenameE (UEffectOpDef rp ~(SourceName pos v) expr) = do - lookupSourceName v >>= \case - UEffectOpVar v' -> UEffectOpDef rp (InternalName pos v v') <$> sourceRenameE expr - _ -> throw TypeErr $ "not an effect operation name: " ++ pprint v +instance SourceRenamableE UMethodDef where + sourceRenameE (WithSrcE sid ((UMethodDef ~(SourceName vSid v) expr))) = WithSrcE sid <$> do + lookupSourceName vSid v >>= \case + UMethodVar v' -> UMethodDef (InternalName vSid v v') <$> sourceRenameE expr + _ -> throw vSid $ NotAMethodName $ pprint v instance SourceRenamableB b => SourceRenamableB (Nest b) where sourceRenameB (Nest b bs) cont = @@ -407,31 +373,32 @@ class SourceRenamablePat (pat::B) where -> m o a instance SourceRenamablePat (UBinder (AtomNameC CoreIR)) where - sourceRenamePat sibs ubinder cont = do + sourceRenamePat sibs (WithSrcB sid ubinder) cont = do newSibs <- case ubinder of - UBindSource _ b -> do - when (S.member b sibs) $ throw RepeatedPatVarErr $ pprint b + UBindSource b -> do + when (S.member b sibs) $ throw sid $ RepeatedPatVarErr $ pprint b return $ S.singleton b UIgnore -> return mempty - UBind _ _ _ -> error "Shouldn't be source-renaming internal names" - sourceRenameB ubinder \ubinder' -> + UBind _ _ -> error "Shouldn't be source-renaming internal names" + sourceRenameB (WithSrcB sid ubinder) \ubinder' -> cont (sibs <> newSibs) ubinder' -instance SourceRenamablePat UPat' where - sourceRenamePat sibs pat cont = case pat of - UPatBinder b -> sourceRenamePat sibs b \sibs' b' -> cont sibs' $ UPatBinder b' +instance SourceRenamablePat UPat where + sourceRenamePat sibs (WithSrcB sid pat) cont = case pat of + UPatBinder b -> sourceRenamePat sibs b \sibs' b' -> + cont sibs' $ WithSrcB sid $ UPatBinder b' UPatCon con bs -> do -- TODO Deduplicate this against the code for sourceRenameE of -- the SourceName case of SourceNameOr con' <- sourceRenameE con sourceRenamePat sibs bs \sibs' bs' -> - cont sibs' $ UPatCon con' bs' + cont sibs' $ WithSrcB sid $ UPatCon con' bs' UPatDepPair (PairB p1 p2) -> sourceRenamePat sibs p1 \sibs' p1' -> sourceRenamePat sibs' p2 \sibs'' p2' -> - cont sibs'' $ UPatDepPair $ PairB p1' p2' - UPatProd bs -> sourceRenamePat sibs bs \sibs' bs' -> cont sibs' $ UPatProd bs' - UPatTable ps -> sourceRenamePat sibs ps \sibs' ps' -> cont sibs' $ UPatTable ps' + cont sibs'' $ WithSrcB sid $ UPatDepPair $ PairB p1' p2' + UPatProd bs -> sourceRenamePat sibs bs \sibs' bs' -> cont sibs' $ WithSrcB sid $ UPatProd bs' + UPatTable ps -> sourceRenamePat sibs ps \sibs' ps' -> cont sibs' $ WithSrcB sid $ UPatTable ps' instance SourceRenamablePat UnitB where sourceRenamePat sibs UnitB cont = cont sibs UnitB @@ -452,11 +419,6 @@ instance (SourceRenamablePat p1, SourceRenamablePat p2) sourceRenamePat sibs p \sibs' p' -> cont sibs' $ RightB p' -instance SourceRenamablePat p => SourceRenamablePat (WithSrcB p) where - sourceRenamePat sibs (WithSrcB pos pat) cont = addSrcContext pos do - sourceRenamePat sibs pat \sibs' pat' -> - cont sibs' $ WithSrcB pos pat' - instance SourceRenamablePat p => SourceRenamablePat (Nest p) where sourceRenamePat sibs (Nest b bs) cont = sourceRenamePat sibs b \sibs' b' -> @@ -464,7 +426,7 @@ instance SourceRenamablePat p => SourceRenamablePat (Nest p) where cont sibs'' $ Nest b' bs' sourceRenamePat sibs Empty cont = cont sibs Empty -instance SourceRenamableB UPat' where +instance SourceRenamableB UPat where sourceRenameB pat cont = sourceRenamePat mempty pat \_ pat' -> cont pat' @@ -482,16 +444,13 @@ class HasSourceNames (b::B) where instance HasSourceNames UTopDecl where sourceNames decl = case decl of ULocalDecl d -> sourceNames d - UDataDefDecl _ ~(UBindSource _ tyConName) dataConNames -> do + UDataDefDecl _ ~(WithSrcB _ (UBindSource tyConName)) dataConNames -> do S.singleton tyConName <> sourceNames dataConNames - UStructDecl ~(UBindSource _ tyConName) _ -> do + UStructDecl ~(WithSrcB _ (UBindSource tyConName)) _ -> do S.singleton tyConName - UInterface _ _ ~(UBindSource _ className) methodNames -> do + UInterface _ _ ~(WithSrcB _ (UBindSource className)) methodNames -> do S.singleton className <> sourceNames methodNames UInstance _ _ _ _ instanceName _ -> sourceNames instanceName - UEffectDecl _ ~(UBindSource _ effName) opNames -> do - S.singleton effName <> sourceNames opNames - UHandlerDecl _ _ _ _ _ _ handlerName -> sourceNames handlerName instance HasSourceNames UDecl' where sourceNames = \case @@ -524,11 +483,11 @@ instance HasSourceNames b => HasSourceNames (Nest b)where sourceNames (Nest b rest) = sourceNames b <> sourceNames rest -instance HasSourceNames (UBinder c) where +instance HasSourceNames (UBinder' c) where sourceNames b = case b of - UBindSource _ name -> S.singleton name + UBindSource name -> S.singleton name UIgnore -> mempty - UBind {} -> error "Shouldn't be source-renaming internal names" + UBind _ _ -> error "Shouldn't be source-renaming internal names" -- === misc instance === diff --git a/src/lib/Subst.hs b/src/lib/Subst.hs index 5b13ef624..b8124d360 100644 --- a/src/lib/Subst.hs +++ b/src/lib/Subst.hs @@ -16,10 +16,13 @@ import Control.Monad.Reader import Control.Monad.State.Strict import Name +import MTL1 import IRVariants import Types.Core +import Types.Top import Core import qualified RawName as R +import QueryTypePure import Err -- === SubstReader class === @@ -35,6 +38,10 @@ dropSubst :: (SubstReader v m, FromName v) => m o o a -> m i o a dropSubst cont = withSubst idSubst cont {-# INLINE dropSubst #-} +withVoidSubst :: (SubstReader v m, FromName v) => m VoidS o a -> m i o a +withVoidSubst cont = withSubst (newSubst absurdNameFunction) cont +{-# INLINE withVoidSubst #-} + extendSubst :: SubstReader v m => SubstFrag v i i' o -> m i' o a -> m i o a extendSubst frag cont = do env <- (<>>frag) <$> getSubst @@ -147,6 +154,12 @@ fromConstAbs (Abs b e) = hoist b e extendRenamer :: (SubstReader v m, FromName v) => SubstFrag Name i i' o -> m i' o r -> m i o r extendRenamer frag = extendSubst (fmapSubstFrag (const fromName) frag) +extendBinderRename + :: (SubstReader v m, FromName v, BindsAtMostOneName b c, BindsOneName b' c) + => b i i' -> b' o o' -> m i' o' r -> m i o' r +extendBinderRename b b' cont = extendSubst (b@>fromName (binderName b')) cont +{-# INLINE extendBinderRename #-} + applyRename :: (ScopeReader m, RenameE e, SinkableE e) => Ext h o => SubstFrag Name h i o -> e i -> m o (e o) @@ -267,6 +280,17 @@ instance ToSubstVal (SubstVal atom) atom where type AtomSubstReader v m = (SubstReader v m, FromName v, ToSubstVal v Atom) +toAtomVar :: (EnvReader m, IRRep r) => AtomName r n -> m n (AtomVar r n) +toAtomVar v = do + ty <- getType <$> lookupAtomName v + return $ AtomVar v ty + +lookupAtomSubst :: (IRRep r, SubstReader AtomSubstVal m, EnvReader2 m) => AtomName r i -> m i o (Atom r o) +lookupAtomSubst v = do + lookupSubstM v >>= \case + Rename v' -> toAtom <$> toAtomVar v' + SubstVal x -> return x + atomSubstM :: (AtomSubstReader v m, EnvReader2 m, SinkableE e, SubstE AtomSubstVal e) => e i -> m i o (e o) atomSubstM e = do @@ -280,36 +304,42 @@ asAtomSubstValSubst subst = newSubst \v -> toSubstVal (subst ! v) -- === SubstReaderT transformer === newtype SubstReaderT (v::V) (m::MonadKind1) (i::S) (o::S) (a:: *) = - SubstReaderT { runSubstReaderT' :: ReaderT (Subst v i o) (m o) a } + SubstReaderT' { runSubstReaderT' :: ReaderT (Subst v i o) (m o) a } + +pattern SubstReaderT :: (Subst v i o -> m o a) -> SubstReaderT v m i o a +pattern SubstReaderT f = SubstReaderT' (ReaderT f) + +runSubstReaderT :: Subst v i o -> SubstReaderT v m i o a -> m o a +runSubstReaderT env m = runReaderT (runSubstReaderT' m) env +{-# INLINE runSubstReaderT #-} instance (forall n. Functor (m n)) => Functor (SubstReaderT v m i o) where - fmap f (SubstReaderT m) = SubstReaderT $ fmap f m + fmap f (SubstReaderT' m) = SubstReaderT' $ fmap f m {-# INLINE fmap #-} instance Monad1 m => Applicative (SubstReaderT v m i o) where - pure = SubstReaderT . pure + pure = SubstReaderT' . pure {-# INLINE pure #-} - liftA2 f (SubstReaderT x) (SubstReaderT y) = SubstReaderT $ liftA2 f x y + liftA2 f (SubstReaderT' x) (SubstReaderT' y) = SubstReaderT' $ liftA2 f x y {-# INLINE liftA2 #-} - (SubstReaderT f) <*> (SubstReaderT x) = SubstReaderT $ f <*> x + (SubstReaderT' f) <*> (SubstReaderT' x) = SubstReaderT' $ f <*> x {-# INLINE (<*>) #-} instance (forall n. Monad (m n)) => Monad (SubstReaderT v m i o) where - return = SubstReaderT . return + return = SubstReaderT' . return {-# INLINE return #-} - (SubstReaderT m) >>= f = SubstReaderT (m >>= (runSubstReaderT' . f)) + (SubstReaderT' m) >>= f = SubstReaderT' (m >>= (runSubstReaderT' . f)) {-# INLINE (>>=) #-} deriving instance (Monad1 m, MonadFail1 m) => MonadFail (SubstReaderT v m i o) deriving instance (Monad1 m, Alternative1 m) => Alternative (SubstReaderT v m i o) -deriving instance (Fallible1 m) => Fallible (SubstReaderT v m i o) +deriving instance Fallible1 m => Fallible (SubstReaderT v m i o) deriving instance Catchable1 m => Catchable (SubstReaderT v m i o) -deriving instance CtxReader1 m => CtxReader (SubstReaderT v m i o) type ScopedSubstReader (v::V) = SubstReaderT v (ScopeReaderT Identity) :: MonadKind2 liftSubstReaderT :: Monad1 m => m o a -> SubstReaderT v m i o a -liftSubstReaderT m = SubstReaderT $ lift m +liftSubstReaderT m = SubstReaderT' $ lift m {-# INLINE liftSubstReaderT #-} runScopedSubstReader :: Distinct o => Scope o -> Subst v i o @@ -318,39 +348,43 @@ runScopedSubstReader scope env m = runIdentity $ runScopeReaderT scope $ runSubstReaderT env m {-# INLINE runScopedSubstReader #-} -runSubstReaderT :: Subst v i o -> SubstReaderT v m i o a -> m o a -runSubstReaderT env m = runReaderT (runSubstReaderT' m) env -{-# INLINE runSubstReaderT #-} - withSubstReaderT :: FromName v => SubstReaderT v m n n a -> m n a withSubstReaderT = runSubstReaderT idSubst {-# INLINE withSubstReaderT #-} instance (SinkableV v, Monad1 m) => SubstReader v (SubstReaderT v m) where - getSubst = SubstReaderT ask + getSubst = SubstReaderT' ask {-# INLINE getSubst #-} - withSubst env (SubstReaderT cont) = SubstReaderT $ withReaderT (const env) cont + withSubst env (SubstReaderT' cont) = SubstReaderT' $ withReaderT (const env) cont {-# INLINE withSubst #-} instance (SinkableV v, ScopeReader m) => ScopeReader (SubstReaderT v m i) where - unsafeGetScope = SubstReaderT $ lift unsafeGetScope + unsafeGetScope = liftSubstReaderT unsafeGetScope {-# INLINE unsafeGetScope #-} - getDistinct = SubstReaderT $ lift getDistinct + getDistinct = liftSubstReaderT getDistinct {-# INLINE getDistinct #-} instance (SinkableV v, EnvReader m) => EnvReader (SubstReaderT v m i) where - unsafeGetEnv = SubstReaderT $ lift unsafeGetEnv + unsafeGetEnv = liftSubstReaderT unsafeGetEnv {-# INLINE unsafeGetEnv #-} instance (SinkableV v, ScopeReader m, EnvExtender m) => EnvExtender (SubstReaderT v m i) where - refreshAbs ab cont = SubstReaderT $ ReaderT \subst -> + refreshAbs ab cont = SubstReaderT \subst -> refreshAbs ab \b e -> do subst' <- sinkM subst - let SubstReaderT (ReaderT cont') = cont b e + let SubstReaderT cont' = cont b e cont' subst' {-# INLINE refreshAbs #-} +instance MonadDiffState1 m s d => MonadDiffState1 (SubstReaderT v m i) s d where + withDiffState s m = + SubstReaderT \subst -> do + withDiffState s $ runSubstReaderT subst m + + updateDiffStateM d = liftSubstReaderT $ updateDiffStateM d + getDiffState = liftSubstReaderT getDiffState + type SubstEnvReaderM v = SubstReaderT v EnvReaderM :: MonadKind2 liftSubstEnvReaderM @@ -362,25 +396,24 @@ liftSubstEnvReaderM cont = liftEnvReaderM $ runSubstReaderT idSubst $ cont instance (SinkableV v, ScopeReader m, ScopeExtender m) => ScopeExtender (SubstReaderT v m i) where - refreshAbsScope ab cont = SubstReaderT $ ReaderT \env -> + refreshAbsScope ab cont = SubstReaderT \env -> refreshAbsScope ab \b e -> do - let SubstReaderT (ReaderT cont') = cont b e + let SubstReaderT cont' = cont b e env' <- sinkM env cont' env' instance (SinkableV v, MonadIO1 m) => MonadIO (SubstReaderT v m i o) where - liftIO m = SubstReaderT $ lift $ liftIO m + liftIO m = liftSubstReaderT $ liftIO m {-# INLINE liftIO #-} instance (Monad1 m, MonadState (s o) (m o)) => MonadState (s o) (SubstReaderT v m i o) where - state = SubstReaderT . lift . state + state = liftSubstReaderT . state {-# INLINE state #-} instance (Monad1 m, MonadReader (r o) (m o)) => MonadReader (r o) (SubstReaderT v m i o) where - ask = SubstReaderT $ ReaderT $ const ask + ask = SubstReaderT $ const ask {-# INLINE ask #-} - local r (SubstReaderT (ReaderT f)) = SubstReaderT $ ReaderT $ \env -> - local r $ f env + local r (SubstReaderT' (ReaderT f)) = SubstReaderT \env -> local r $ f env {-# INLINE local #-} -- === instances === @@ -466,6 +499,9 @@ instance FromName v => SubstE v (LiftE a) where instance SubstE v e => SubstE v (ListE e) where substE env (ListE xs) = ListE $ map (substE env) xs +instance SubstE v e => SubstE v (RListE e) where + substE env (RListE xs) = RListE $ fmap (substE env) xs + instance SubstE v e => SubstE v (NonEmptyListE e) where substE env (NonEmptyListE xs) = NonEmptyListE $ fmap (substE env) xs diff --git a/src/lib/TopLevel.hs b/src/lib/TopLevel.hs index f2379f259..d1932bbf8 100644 --- a/src/lib/TopLevel.hs +++ b/src/lib/TopLevel.hs @@ -8,12 +8,12 @@ module TopLevel ( EvalConfig (..), Topper, TopperM, runTopperM, - evalSourceBlock, evalSourceBlockRepl, OptLevel (..), + evalSourceBlockRepl, OptLevel (..), evalSourceText, TopStateEx (..), LibPath (..), evalSourceBlockIO, initTopState, loadCache, storeCache, clearCache, ensureModuleLoaded, importModule, printCodegen, loadObject, toCFunction, packageLLVMCallable, - simpOptimizations, loweredOptimizations, compileTopLevelFun) where + simpOptimizations, loweredOptimizations, compileTopLevelFun, ErrorHandling (..)) where import Data.Functor import Data.Maybe (catMaybes) @@ -26,7 +26,6 @@ import Data.Text (Text) import Data.Text.Prettyprint.Doc import Data.Store (encode, decode) import Data.String (fromString) -import Data.List (partition) import qualified Data.Map.Strict as M import qualified Data.Set as S import Foreign.Ptr @@ -34,7 +33,7 @@ import Foreign.C.String import GHC.Generics (Generic (..)) import System.FilePath import System.Directory -import System.IO (stderr, hPutStrLn, Handle) +import System.IO (stderr, hPutStrLn) import System.IO.Error (isDoesNotExistError) import LLVM.Link @@ -54,88 +53,79 @@ import Err import IRVariants import Imp import ImpToLLVM +import IncState import Inference import Inline -import Logging import Lower +import MonadUtil import MTL1 -import SourceInfo import Subst import Name import OccAnalysis import Optimize -import PPrint (pprintCanonicalized) import Paths_dex (getDataFileName) import QueryType import Runtime import Serialize (takePtrSnapshot, restorePtrSnapshot) import Simplify import SourceRename +import SourceIdTraversal +import PPrint import Types.Core import Types.Imp -import Types.Misc import Types.Primitives import Types.Source -import Util ( Tree (..), measureSeconds, File (..), readFileWithHash) +import Types.Top +import Util ( Tree (..), File (..), readFileWithHash) import Vectorize -- === top-level monad === data LibPath = LibDirectory FilePath | LibBuiltinPath +data ErrorHandling = HaltOnErr | ContinueOnErr data EvalConfig = EvalConfig { backendName :: Backend , libPaths :: [LibPath] , preludeFile :: Maybe FilePath - , logFileName :: Maybe FilePath - , logFile :: Maybe Handle , optLevel :: OptLevel - , printBackend :: PrintBackend } + , printBackend :: PrintBackend + , errorHandling :: ErrorHandling + , cfgLogLevel :: LogLevel + , cfgLogAction :: Outputs -> IO ()} class Monad m => ConfigReader m where getConfig :: m EvalConfig -data PassCtx = PassCtx - { requiresBench :: BenchRequirement - , shouldLogPass :: PassName -> Bool - } - -initPassCtx :: PassCtx -initPassCtx = PassCtx NoBench (const True) - -class Monad m => PassCtxReader m where - getPassCtx :: m PassCtx - withPassCtx :: PassCtx -> m a -> m a - class Monad m => RuntimeEnvReader m where getRuntimeEnv :: m RuntimeEnv -type TopLogger m = (MonadIO m, MonadLogger [Output] m) +type TopLogger m = (MonadIO m, Logger Outputs m) class ( forall n. Fallible (m n) - , forall n. MonadLogger [Output] (m n) + , forall n. Logger Outputs (m n) + , forall n. HasIOLogger Outputs (m n) , forall n. Catchable (m n) , forall n. ConfigReader (m n) - , forall n. PassCtxReader (m n) , forall n. RuntimeEnvReader (m n) , forall n. MonadIO (m n) -- TODO: something more restricted here , TopBuilder m ) => Topper m data TopperReaderData = TopperReaderData - { topperPassCtx :: PassCtx - , topperEvalConfig :: EvalConfig + { topperEvalConfig :: EvalConfig , topperRuntimeEnv :: RuntimeEnv } newtype TopperM (n::S) a = TopperM { runTopperM' - :: TopBuilderT (ReaderT TopperReaderData (LoggerT [Output] IO)) n a } + :: TopBuilderT (ReaderT TopperReaderData IO) n a } deriving ( Functor, Applicative, Monad, MonadIO, MonadFail , Fallible, EnvReader, ScopeReader, Catchable) -- Hides the `n` parameter as an existential data TopStateEx where TopStateEx :: Distinct n => Env n -> RuntimeEnv -> TopStateEx +instance Show TopStateEx where show _ = "TopStateEx" -- Hides the `n` parameter as an existential data TopSerializedStateEx where @@ -146,9 +136,8 @@ runTopperM -> (forall n. Mut n => TopperM n a) -> IO (a, TopStateEx) runTopperM opts (TopStateEx env rtEnv) cont = do - let maybeLogFile = logFile opts - (Abs frag (LiftE result), _) <- runLogger maybeLogFile \l -> runLoggerT l $ - flip runReaderT (TopperReaderData initPassCtx opts rtEnv) $ + Abs frag (LiftE result) <- + flip runReaderT (TopperReaderData opts rtEnv) $ runTopBuilderT env $ runTopperM' do localTopBuilder $ LiftE <$> cont return (result, extendTopEnv env rtEnv frag) @@ -171,45 +160,42 @@ allocateDynamicVarKeyPtrs = do -- ====== evalSourceBlockIO - :: EvalConfig -> TopStateEx -> SourceBlock -> IO (Result, TopStateEx) + :: EvalConfig -> TopStateEx -> SourceBlock -> IO TopStateEx evalSourceBlockIO opts env block = - runTopperM opts env $ evalSourceBlockRepl block + liftM snd $ runTopperM opts env $ evalSourceBlockRepl block -- Used for the top-level source file (rather than imported modules) -evalSourceText - :: (Topper m, Mut n) - => Text -> (SourceBlock -> IO ()) -> (Result -> IO Bool) - -> m n [(SourceBlock, Result)] -evalSourceText source beginCallback endCallback = do - let (UModule mname deps sbs) = parseUModule Main source +evalSourceText :: (Topper m, Mut n) => Text -> (SourceBlock -> IO ()) -> m n () +evalSourceText source logSourceBlock = do + let UModule mname deps sbs = parseUModule Main source mapM_ ensureModuleLoaded deps evalSourceBlocks mname sbs where evalSourceBlocks mname = \case - [] -> return [] - (sb:rest) -> do - liftIO $ beginCallback sb - result <- evalSourceBlock mname sb - liftIO (endCallback result) >>= \case - False -> return [(sb, result)] - True -> ((sb, result):) <$> evalSourceBlocks mname rest - -catchLogsAndErrs :: (Topper m, Mut n) => m n a -> m n (Except a, [Output]) -catchLogsAndErrs m = do - maybeLogFile <- logFile <$> getConfig - runLogger maybeLogFile \l -> withLogger l $ - catchErrExcept m + [] -> return () + sb:rest -> do + liftIO $ logSourceBlock sb + evalSourceBlock mname sb >>= \case + Success () -> return () + Failure e -> do + logTop $ Error e + (errorHandling <$> getConfig) >>= \case + HaltOnErr -> return () + ContinueOnErr -> evalSourceBlocks mname rest -- Module imports have to be handled differently in the repl because we don't -- know ahead of time which modules will be needed. -evalSourceBlockRepl :: (Topper m, Mut n) => SourceBlock -> m n Result +evalSourceBlockRepl :: (Topper m, Mut n) => SourceBlock -> m n () evalSourceBlockRepl block = do - case block of - SourceBlock _ _ _ _ (Misc (ImportModule name)) -> do + case sbContents block of + Misc (ImportModule name) -> do -- TODO: clear source map and synth candidates before calling this ensureModuleLoaded name _ -> return () - evalSourceBlock Main block + maybeErr <- evalSourceBlock Main block + case maybeErr of + Success () -> return () + Failure e -> logTop $ Error e -- XXX: This ensures that a module and its transitive dependencies are loaded, -- (which will require evaluating them if they're not in the cache) but it @@ -225,23 +211,18 @@ ensureModuleLoaded moduleSourceName = do {-# SCC ensureModuleLoaded #-} evalSourceBlock - :: (Topper m, Mut n) => ModuleSourceName -> SourceBlock -> m n Result + :: (Topper m, Mut n) => ModuleSourceName -> SourceBlock -> m n (Except ()) evalSourceBlock mname block = do - result <- withCompileTime do - (maybeErr, logs) <- catchLogsAndErrs do - benchReq <- getBenchRequirement block - withPassCtx (PassCtx benchReq (passLogFilter $ sbLogLevel block)) $ - evalSourceBlock' mname block - return $ Result logs maybeErr - case resultErrs result of - Failure _ -> case sbContents block of - TopDecl decl -> do - case runFallibleM (parseDecl decl) of - Success decl' -> emitSourceMap $ uDeclErrSourceMap mname decl' - Failure _ -> return () - _ -> return () + maybeErr <- catchErrExcept do + logTop $ SourceInfo $ SIGroupTree $ OverwriteWith $ getGroupTree $ sbContents block + evalSourceBlock' mname block + case (maybeErr, sbContents block) of + (Failure _, TopDecl decl) -> do + case parseDecl decl of + Success decl' -> emitSourceMap $ uDeclErrSourceMap mname decl' + Failure _ -> return () _ -> return () - return $ filterLogs block $ addResultCtx block result + return maybeErr evalSourceBlock' :: (Topper m, Mut n) => ModuleSourceName -> SourceBlock -> m n () @@ -266,7 +247,7 @@ evalSourceBlock' mname block = case sbContents block of s <- getDexString stringVal logTop $ TextOut s RenderHtml -> do - stringVal <- evalUExpr $ addTypeAnn expr (referTo "String") + stringVal <- evalUExpr $ addTypeAnn expr (referTo $ WithSrc (srcPos expr) "String") s <- getDexString stringVal logTop $ HtmlOut s ExportFun _ -> error "not implemented" @@ -278,39 +259,36 @@ evalSourceBlock' mname block = case sbContents block of -- logTop $ ExportedFun name f GetType -> do -- TODO: don't actually evaluate it val <- evalUExpr expr - ty <- cheapNormalize $ getType val - logTop $ TextOut $ pprintCanonicalized ty - DeclareForeign fname dexName cTy -> do - let b = fromString dexName :: UBinder (AtomNameC CoreIR) VoidS VoidS + logTop $ TextOut $ pprintCanonicalized $ getType val + DeclareForeign fname (WithSrc _ dexName) cTy -> do ty <- evalUType =<< parseExpr cTy asFFIFunType ty >>= \case - Nothing -> throw TypeErr + Nothing -> throwErr $ MiscErr $ MiscMiscErr "FFI functions must be n-ary first order functions with the IO effect" Just (impFunTy, naryPiTy) -> do -- TODO: query linking stuff and check the function is actually available - let hint = getNameHint b - fTop <- emitBinding hint $ TopFunBinding $ FFITopFun fname impFunTy + let hint = fromString $ pprint dexName + fTop <- emitBinding hint $ TopFunBinding $ FFITopFun (pprint $ withoutSrc fname) impFunTy vCore <- emitBinding hint $ AtomNameBinding $ FFIFunBound naryPiTy fTop - UBindSource _ sourceName <- return b emitSourceMap $ SourceMap $ - M.singleton sourceName [ModuleVar mname (Just $ UAtomVar vCore)] + M.singleton dexName [ModuleVar mname (Just $ UAtomVar vCore)] DeclareCustomLinearization fname zeros g -> do expr <- parseExpr g - lookupSourceMap fname >>= \case - Nothing -> throw UnboundVarErr $ pprint fname + lookupSourceMap (withoutSrc fname) >>= \case + Nothing -> throw rootSrcId $ UnboundVarErr $ pprint fname Just (UAtomVar fname') -> do lookupCustomRules fname' >>= \case Nothing -> return () - Just _ -> throw TypeErr + Just _ -> throwErr $ MiscErr $ MiscMiscErr $ pprint fname ++ " already has a custom linearization" lookupAtomName fname' >>= \case NoinlineFun _ _ -> return () - _ -> throw TypeErr "Custom linearizations only apply to @noinline functions" + _ -> throwErr $ MiscErr $ MiscMiscErr "Custom linearizations only apply to @noinline functions" -- We do some special casing to avoid instantiating polymorphic functions. impl <- case expr of WithSrcE _ (UVar _) -> renameSourceNamesUExpr expr >>= \case - WithSrcE _ (UVar (InternalName _ _ (UAtomVar v))) -> Var <$> toAtomVar v + WithSrcE _ (UVar (InternalName _ _ (UAtomVar v))) -> toAtom <$> toAtomVar v _ -> error "Expected a variable" _ -> evalUExpr expr fType <- getType <$> toAtomVar fname' @@ -318,18 +296,20 @@ evalSourceBlock' mname block = case sbContents block of liftEnvReaderT (impl `checkTypeIs` linFunTy) >>= \case Failure _ -> do let implTy = getType impl - throw TypeErr $ unlines + throwErr $ MiscErr $ MiscMiscErr $ unlines [ "Expected the custom linearization to have type:" , "" , pprint linFunTy , "" , "but it has type:" , "" , pprint implTy] Success () -> return () updateTopEnv $ AddCustomRule fname' $ CustomLinearize nimplicit nexplicit zeros impl - Just _ -> throw TypeErr - $ "Custom linearization can only be defined for functions" - UnParseable _ s -> throw ParseErr s + Just _ -> throwErr $ MiscErr $ MiscMiscErr $ "Custom linearization can only be defined for functions" + UnParseable _ s -> throwErr $ ParseErr $ MiscParseErr s Misc m -> case m of GetNameType v -> do - ty <- cheapNormalize =<< sourceNameType v - logTop $ TextOut $ pprintCanonicalized ty + lookupSourceMap (withoutSrc v) >>= \case + Nothing -> throw rootSrcId $ UnboundVarErr $ pprint v + Just uvar -> do + ty <- getUVarType uvar + logTop $ TextOut $ pprintCanonicalized ty ImportModule moduleName -> importModule moduleName QueryEnv query -> void $ runEnvQuery query $> UnitE ProseBlock _ -> return () @@ -337,11 +317,11 @@ evalSourceBlock' mname block = case sbContents block of EmptyLines -> return () where addTypeAnn :: UExpr n -> UExpr n -> UExpr n - addTypeAnn e = WithSrcE emptySrcPosCtx . UTypeAnn e + addTypeAnn e = WithSrcE (srcPos e) . UTypeAnn e addShowAny :: UExpr n -> UExpr n - addShowAny e = WithSrcE emptySrcPosCtx $ UApp (referTo "show_any") [e] [] - referTo :: SourceName -> UExpr n - referTo = WithSrcE emptySrcPosCtx . UVar . SourceName emptySrcPosCtx + addShowAny e = WithSrcE (srcPos e) $ UApp (referTo $ WithSrc (srcPos e) "show_any") [e] [] + referTo :: SourceNameW -> UExpr n + referTo (WithSrc sid name) = WithSrcE sid $ UVar $ SourceName sid name runEnvQuery :: Topper m => EnvQuery -> m n () runEnvQuery query = do @@ -350,11 +330,11 @@ runEnvQuery query = do DumpSubst -> logTop $ TextOut $ pprint $ env InternalNameInfo name -> case lookupSubstFragRaw (fromRecSubst $ envDefs $ topEnv env) name of - Nothing -> throw UnboundVarErr $ pprint name + Nothing -> throw rootSrcId $ UnboundVarErr $ pprint name Just binding -> logTop $ TextOut $ pprint binding SourceNameInfo name -> do lookupSourceMap name >>= \case - Nothing -> throw UnboundVarErr $ pprint name + Nothing -> throw rootSrcId $ UnboundVarErr $ pprint name Just uvar -> do logTop $ TextOut $ pprint uvar info <- case uvar of @@ -363,19 +343,11 @@ runEnvQuery query = do UDataConVar v' -> pprint <$> lookupEnv v' UClassVar v' -> pprint <$> lookupEnv v' UMethodVar v' -> pprint <$> lookupEnv v' - UEffectVar v' -> pprint <$> lookupEnv v' - UEffectOpVar v' -> pprint <$> lookupEnv v' UPunVar v' -> do val <- lookupEnv v' return $ pprint val ++ "\n(type constructor and data constructor share the same name)" logTop $ TextOut $ "Binding:\n" ++ info -filterLogs :: SourceBlock -> Result -> Result -filterLogs block (Result outs err) = let - (logOuts, requiredOuts) = partition isLogInfo outs - outs' = requiredOuts ++ processLogs (sbLogLevel block) logOuts - in Result outs' err - -- returns a toposorted list of the module's transitive dependencies (including -- the module itself) excluding those provided in the set of already known -- modules. @@ -431,7 +403,7 @@ evalPartiallyParsedUModuleCached md@(UModulePartialParse name deps source) = do directDeps <- forM deps \dep -> do lookupLoadedModule dep >>= \case Just depVal -> return depVal - Nothing -> throw CompilerErr $ pprint dep ++ " isn't loaded" + Nothing -> throwInternal $ pprint dep ++ " isn't loaded" let req = (fHash source, directDeps) case M.lookup name cache of Just (cachedReq, result) | cachedReq == req -> return result @@ -465,7 +437,7 @@ evalUModule (UModule name _ blocks) = do importModule :: (Mut n, TopBuilder m, Fallible1 m) => ModuleSourceName -> m n () importModule name = do lookupLoadedModule name >>= \case - Nothing -> throw ModuleImportErr $ "Couldn't import " ++ pprint name + Nothing -> throwErr $ MiscErr $ ModuleImportErr $ pprint name Just name' -> do Module _ _ transImports' _ _ <- lookupModule name' let importStatus = ImportStatus (S.singleton name') @@ -473,58 +445,15 @@ importModule name = do emitLocalModuleEnv $ mempty { envImportStatus = importStatus } {-# SCC importModule #-} -passLogFilter :: LogLevel -> PassName -> Bool -passLogFilter = \case - LogAll -> const True - LogNothing -> const False - LogPasses passes -> (`elem` passes) - PrintEvalTime -> const False - PrintBench _ -> const False - -processLogs :: LogLevel -> [Output] -> [Output] -processLogs logLevel logs = case logLevel of - LogAll -> logs - LogNothing -> [] - LogPasses passes -> flip filter logs \case - PassInfo pass _ | pass `elem` passes -> True - | otherwise -> False - _ -> False - PrintEvalTime -> [BenchResult "" compileTime runTime benchStats] - where (compileTime, runTime, benchStats) = timesFromLogs logs - PrintBench benchName -> [BenchResult benchName compileTime runTime benchStats] - where (compileTime, runTime, benchStats) = timesFromLogs logs - -timesFromLogs :: [Output] -> (Double, Double, Maybe BenchStats) -timesFromLogs logs = (totalTime - totalEvalTime, singleEvalTime, benchStats) - where - (totalEvalTime, singleEvalTime, benchStats) = - case [(t, stats) | EvalTime t stats <- logs] of - [] -> (0.0 , 0.0, Nothing) - [(t, stats)] -> (total, t , stats) - where total = maybe t snd stats - _ -> error "Expect at most one result" - totalTime = case [tTotal | TotalTime tTotal <- logs] of - [] -> 0.0 - [t] -> t - _ -> error "Expect at most one result" - -isLogInfo :: Output -> Bool -isLogInfo out = case out of - PassInfo _ _ -> True - MiscLog _ -> True - EvalTime _ _ -> True - TotalTime _ -> True - _ -> False - evalUType :: (Topper m, Mut n) => UType VoidS -> m n (CType n) evalUType ty = do - logTop $ PassInfo Parse $ pprint ty + logDebug $ return $ PassInfo Parse $ pprint ty renamed <- logPass RenamePass $ renameSourceNamesUExpr ty checkPass TypePass $ checkTopUType renamed evalUExpr :: (Topper m, Mut n) => UExpr VoidS -> m n (CAtom n) evalUExpr expr = do - logTop $ PassInfo Parse $ pprint expr + logDebug $ return $ PassInfo Parse $ pprint expr renamed <- logPass RenamePass $ renameSourceNamesUExpr expr typed <- checkPass TypePass $ inferTopUExpr renamed evalBlock typed @@ -536,15 +465,10 @@ whenOpt x act = getConfig <&> optLevel >>= \case evalBlock :: (Topper m, Mut n) => TopBlock CoreIR n -> m n (CAtom n) evalBlock typed = do - -- Be careful when adding new compilation passes here. If you do, be sure to - -- also check compileTopLevelFun, below, and Export.prepareFunctionForExport. - -- In most cases it should be easiest to add new passes to simpOptimizations or - -- loweredOptimizations, below, because those are reused in all three places. - synthed <- checkPass SynthPass $ synthTopE typed - SimplifiedTopLam simp recon <- checkPass SimpPass $ simplifyTopBlock synthed + SimplifiedTopLam simp recon <- checkPass SimpPass $ simplifyTopBlock typed opt <- simpOptimizations simp simpResult <- case opt of - TopLam _ _ (LamExpr Empty (WithoutDecls result)) -> return result + TopLam _ _ (LamExpr Empty (Atom result)) -> return result _ -> do lowered <- checkPass LowerPass $ lowerFullySequential True opt lOpt <- checkPass OptPass $ loweredOptimizations lowered @@ -572,8 +496,7 @@ loweredOptimizations lowered = do (dceTop >=> hoistLoopInvariant) whenOpt lopt \lo -> do (vo, errs) <- vectorizeLoops 64 lo - l <- getFilteredLogger - logFiltered l VectPass $ return [TextOut $ pprint errs] + logTop $ TextOut $ pprint errs checkPass VectPass $ return vo loweredOptimizationsNoDest :: Topper m => STopLam n -> m n (STopLam n) @@ -615,7 +538,7 @@ evalDictSpecializations ds = do execUDecl :: (Topper m, Mut n) => ModuleSourceName -> UTopDecl VoidS VoidS -> m n () execUDecl mname decl = do - logTop $ PassInfo Parse $ pprint decl + logDebug $ return $ PassInfo Parse $ pprint decl Abs renamedDecl sourceMap <- logPass RenamePass $ renameSourceNamesTopUDecl mname decl inferenceResult <- checkPass TypePass $ inferTopUDecl renamedDecl sourceMap @@ -647,9 +570,8 @@ compileTopLevelFun cc fSimp = do printCodegen :: (Topper m, Mut n) => CAtom n -> m n String printCodegen x = do - block <- liftBuilder $ buildBlock do - emitExpr $ PrimOp $ MiscOp $ ShowAny $ sink x - topBlock <- asTopBlock block + block <- liftBuilder $ buildBlock $ emit $ ShowAny $ sink x + (topBlock, _) <- asTopBlock block getDexString =<< evalBlock topBlock loadObject :: (Topper m, Mut n) => FunObjCodeName n -> m n NativeFunction @@ -688,7 +610,7 @@ linkFunObjCode objCode dyvarStores (LinktimeVals funVals ptrVals) = do toCFunction :: (Topper m, Mut n) => NameHint -> ImpFunction n -> m n (CFunction n) toCFunction nameHint impFun = do - logger <- getFilteredLogger + logger <- getIOLogger (closedImpFun, reqFuns, reqPtrNames) <- abstractLinktimeObjects impFun obj <- impToLLVM logger nameHint closedImpFun >>= compileToObjCode reqObjNames <- mapM funNameToObj reqFuns @@ -708,14 +630,13 @@ packageLLVMCallable :: forall n m. (Topper m, Mut n) => ImpFunction n -> m n LLVMCallable packageLLVMCallable impFun = do nativeFun <- toCFunction "main" impFun >>= loadObjectContent - benchRequired <- requiresBench <$> getPassCtx - logger <- getFilteredLogger + logger <- getIOLogger let IFunType _ _ resultTypes = impFunType impFun return LLVMCallable{..} compileToObjCode :: Topper m => WithCNameInterface LLVM.AST.Module -> m n FunObjCode compileToObjCode astWithNames = forM astWithNames \ast -> do - logger <- getFilteredLogger + logger <- getIOLogger opt <- getLLVMOptLevel <$> getConfig liftIO $ compileLLVM logger opt ast (cniMainFunName astWithNames) @@ -726,11 +647,6 @@ funNameToObj v = do TopFunBinding (DexTopFun _ _ (Finished impl)) -> return $ topFunObjCode impl b -> error $ "couldn't find object cache entry for " ++ pprint v ++ "\ngot:\n" ++ pprint b -withCompileTime :: MonadIO m => m Result -> m Result -withCompileTime m = do - (Result outs err, t) <- measureSeconds m - return $ Result (outs ++ [TotalTime t]) err - checkPass :: (Topper m, Pretty (e n), CheckableE r e) => PassName -> m n (e n) -> m n (e n) checkPass name cont = do @@ -738,35 +654,37 @@ checkPass name cont = do result <- cont return result #ifdef DEX_DEBUG - logTop $ MiscLog $ "Running checks" + logDebug $ return $ MiscLog $ "Running checks" checkTypes result - logTop $ MiscLog $ "Checks passed" + logDebug $ return $ MiscLog $ "Checks passed" #else - logTop $ MiscLog $ "Checks skipped (not a debug build)" + logDebug $ return $ MiscLog $ "Checks skipped (not a debug build)" #endif return result -addResultCtx :: SourceBlock -> Result -> Result -addResultCtx block (Result outs errs) = - Result outs (addSrcTextContext (sbOffset block) (sbText block) errs) - logTop :: TopLogger m => Output -> m () -logTop x = logIO [x] +logTop x = emitLog $ Outputs [x] + +logDebug :: TopLogger m => m Output -> m () +logDebug m = getLogLevel >>= \case + NormalLogLevel -> return () + DebugLogLevel -> do + x <- m + emitLog $ Outputs [x] logPass :: Topper m => Pretty a => PassName -> m n a -> m n a logPass passName cont = do - logTop $ PassInfo passName $ "=== " <> pprint passName <> " ===" - logTop $ MiscLog $ "Starting "++ pprint passName + logDebug $ return $ PassInfo passName $ "=== " <> pprint passName <> " ===" + logDebug $ return $ MiscLog $ "Starting "++ pprint passName result <- cont - {-# SCC logPassPrinting #-} logTop $ PassInfo passName - $ "=== Result ===\n" <> pprint result + logDebug $ return $ PassInfo passName $ "=== Result ===\n" <> pprint result return result loadModuleSource :: (MonadIO m, Fallible m) => EvalConfig -> ModuleSourceName -> m File loadModuleSource config moduleName = do fullPath <- case moduleName of - OrdinaryModule moduleName' -> findFullPath $ moduleName' ++ ".dx" + OrdinaryModule moduleName' -> findFullPath $ pprint moduleName' ++ ".dx" Prelude -> case preludeFile config of Nothing -> findFullPath "prelude.dx" Just path -> return path @@ -778,31 +696,16 @@ loadModuleSource config moduleName = do fsPaths <- liftIO $ traverse resolveBuiltinPath $ libPaths config liftIO (findFile fsPaths fname) >>= \case Just fpath -> return fpath - Nothing -> throw ModuleImportErr $ unlines - [ "Couldn't find a source file for module " ++ - (case moduleName of - OrdinaryModule n -> n; Prelude -> "prelude"; Main -> error "") - , "Hint: Consider extending --lib-path?" - ] - + Nothing -> throwErr $ MiscErr $ CantFindModuleSource $ pprint moduleName resolveBuiltinPath = \case LibBuiltinPath -> liftIO $ getDataFileName "lib" LibDirectory dir -> return dir {-# SCC loadModuleSource #-} -getBenchRequirement :: Topper m => SourceBlock -> m n BenchRequirement -getBenchRequirement block = case sbLogLevel block of - PrintBench _ -> do - backend <- backendName <$> getConfig - let needsSync = case backend of LLVMCUDA -> True - _ -> False - return $ DoBench needsSync - _ -> return NoBench - getDexString :: (MonadIO1 m, EnvReader m, Fallible1 m) => Val CoreIR n -> m n String getDexString val = do -- TODO: use a `ByteString` instead of `String` - SimpInCore (LiftSimp _ (RepValAtom (RepVal _ tree))) <- return val + Stuck _ (LiftSimp _ (RepValAtom (RepVal _ tree))) <- return val Branch [Leaf (IIdxRepVal n), Leaf (IPtrVar ptrName _)] <- return tree PtrBinding (CPU, Scalar Word8Type) (PtrLitVal ptr) <- lookupEnv ptrName liftIO $ peekCStringLen (castPtr ptr, fromIntegral n) @@ -881,20 +784,8 @@ restorePtrSnapshots s = traverseBindingsTopStateEx s \case PtrBinding ty p -> liftIO $ PtrBinding ty <$> restorePtrSnapshot p b -> return b -getFilteredLogger :: Topper m => m n PassLogger -getFilteredLogger = do - shouldLog <- shouldLogPass <$> getPassCtx - logger <- getLogger - return $ FilteredLogger shouldLog logger - -- === instances === -instance PassCtxReader (TopperM n) where - getPassCtx = TopperM $ asks topperPassCtx - withPassCtx ctx cont = TopperM $ - liftTopBuilderTWith (local \r -> r {topperPassCtx = ctx}) $ - runTopperM' cont - instance RuntimeEnvReader (TopperM n) where getRuntimeEnv = TopperM $ asks topperRuntimeEnv @@ -917,10 +808,14 @@ instance TopBuilder TopperM where emitNamelessEnv env = TopperM $ emitNamelessEnv env localTopBuilder cont = TopperM $ localTopBuilder $ runTopperM' cont -instance MonadLogger [Output] (TopperM n) where - getLogger = TopperM $ lift1 $ lift $ getLogger - withLogger l cont = - TopperM $ liftTopBuilderTWith (withLogger l) (runTopperM' cont) +instance Logger Outputs (TopperM n) where + emitLog x = do + logger <- getIOLogAction + liftIO $ logger x + getLogLevel = cfgLogLevel <$> getConfig + +instance HasIOLogger Outputs (TopperM n) where + getIOLogAction = cfgLogAction <$> getConfig instance Generic TopStateEx where type Rep TopStateEx = Rep (Env UnsafeS, RuntimeEnv) @@ -931,7 +826,7 @@ instance Generic TopStateEx where getLinearizationType :: SymbolicZeros -> CType n -> EnvReaderT Except n (Int, Int, CType n) getLinearizationType zeros = \case - Pi (CorePiType ExplicitApp expls bs (EffTy Pure resultTy)) -> do + TyCon (Pi (CorePiType ExplicitApp expls bs (EffTy Pure resultTy))) -> do (numIs, numEs) <- getNumImplicits expls refreshAbs (Abs bs resultTy) \bs' resultTy' -> do PairB _ bsE <- return $ splitNestAt numIs bs' @@ -940,14 +835,14 @@ getLinearizationType zeros = \case Just tty -> case zeros of InstantiateZeros -> return tty SymbolicZeros -> symbolicTangentTy tty - Nothing -> throw TypeErr $ "No tangent type for: " ++ pprint t + Nothing -> throwErr $ MiscErr $ MiscMiscErr $ "No tangent type for: " ++ pprint t resultTanTy <- maybeTangentType resultTy' >>= \case Just rtt -> return rtt - Nothing -> throw TypeErr $ "No tangent type for: " ++ pprint resultTy' - let tanFunTy = Pi $ nonDepPiType argTanTys Pure resultTanTy + Nothing -> throwErr $ MiscErr $ MiscMiscErr $ "No tangent type for: " ++ pprint resultTy' + let tanFunTy = toType $ Pi $ nonDepPiType argTanTys Pure resultTanTy let fullTy = CorePiType ExplicitApp expls bs' $ EffTy Pure (PairTy resultTy' tanFunTy) - return (numIs, numEs, Pi fullTy) - _ -> throw TypeErr $ "Can't define a custom linearization for implicit or impure functions" + return (numIs, numEs, toType $ Pi fullTy) + _ -> throwErr $ MiscErr $ MiscMiscErr $ "Can't define a custom linearization for implicit or impure functions" where getNumImplicits :: Fallible m => [Explicitness] -> m (Int, Int) getNumImplicits = \case @@ -958,4 +853,4 @@ getLinearizationType zeros = \case Inferred _ _ -> return (ni + 1, ne) Explicit -> case ni of 0 -> return (0, ne + 1) - _ -> throw TypeErr "All implicit args must precede implicit args" + _ -> throwErr $ MiscErr $ MiscMiscErr "All implicit args must precede implicit args" diff --git a/src/lib/Transpose.hs b/src/lib/Transpose.hs index 904e608d1..3e361d0d3 100644 --- a/src/lib/Transpose.hs +++ b/src/lib/Transpose.hs @@ -9,21 +9,18 @@ module Transpose (transpose, transposeTopFun) where import Data.Foldable import Data.Functor import Control.Category ((>>>)) -import Control.Monad.Reader -import qualified Data.Set as S import GHC.Stack import Builder import Core -import CheapReduction -import Err import Imp import IRVariants -import MTL1 import Name +import PPrint import Subst import QueryType import Types.Core +import Types.Top import Types.Primitives import Util (enumerate) @@ -34,35 +31,32 @@ transpose lam ct = liftEmitBuilder $ runTransposeM do UnaryLamExpr b body <- sinkM lam withAccumulator (binderType b) \refSubstVal -> extendSubst (b @> refSubstVal) $ - transposeBlock body (sink ct) + transposeExpr body (sink ct) {-# SCC transpose #-} runTransposeM :: TransposeM n n a -> BuilderM SimpIR n a -runTransposeM cont = runReaderT1 (ListE []) $ runSubstReaderT idSubst $ cont +runTransposeM cont = runSubstReaderT idSubst $ cont -transposeTopFun - :: (MonadFail1 m, EnvReader m) - => STopLam n -> m n (STopLam n) +transposeTopFun :: (MonadFail1 m, EnvReader m) => STopLam n -> m n (STopLam n) transposeTopFun (TopLam False _ lam) = liftBuilder $ runTransposeM do (Abs bsNonlin (Abs bLin body), Abs bsNonlin'' outTy) <- unpackLinearLamExpr lam refreshBinders bsNonlin \bsNonlin' substFrag -> extendRenamer substFrag do outTy' <- applyRename (bsNonlin''@@> nestToNames bsNonlin') outTy withFreshBinder "ct" outTy' \bCT -> do - let ct = Var $ binderVar bCT + let ct = toAtom $ binderVar bCT body' <- buildBlock do inTy <- substNonlin $ binderType bLin withAccumulator inTy \refSubstVal -> extendSubst (bLin @> refSubstVal) $ - transposeBlock body (sink ct) - EffTy _ bodyTy <- blockEffTy body' - let piTy = PiType (bsNonlin' >>> UnaryNest bCT) (EffTy Pure bodyTy) + transposeExpr body (sink ct) + let piTy = PiType (bsNonlin' >>> UnaryNest bCT) (EffTy Pure (getType body')) let lamT = LamExpr (bsNonlin' >>> UnaryNest bCT) body' return $ TopLam False piTy lamT transposeTopFun (TopLam True _ _) = error "shouldn't be transposing in destination passing style" unpackLinearLamExpr :: (MonadFail1 m, EnvReader m) => LamExpr SimpIR n - -> m n ( Abs (Nest SBinder) (Abs SBinder SBlock) n + -> m n ( Abs (Nest SBinder) (Abs SBinder SExpr) n , Abs (Nest SBinder) SType n) unpackLinearLamExpr lam@(LamExpr bs body) = do let numNonlin = nestLength bs - 1 @@ -75,55 +69,22 @@ unpackLinearLamExpr lam@(LamExpr bs body) = do -- === transposition monad === +type AtomTransposeSubstVal = TransposeSubstVal (AtomNameC SimpIR) data TransposeSubstVal c n where RenameNonlin :: Name c n -> TransposeSubstVal c n -- accumulator references corresponding to non-ref linear variables - LinRef :: SAtom n -> TransposeSubstVal (AtomNameC SimpIR) n + LinRef :: SAtom n -> AtomTransposeSubstVal n -- as an optimization, we don't make references for trivial vector spaces - LinTrivial :: TransposeSubstVal (AtomNameC SimpIR) n + LinTrivial :: AtomTransposeSubstVal n -type LinRegions = ListE SAtomVar +type TransposeM a = SubstReaderT TransposeSubstVal (BuilderM SimpIR) a -type TransposeM a = SubstReaderT TransposeSubstVal - (ReaderT1 LinRegions (BuilderM SimpIR)) a - -type TransposeM' a = SubstReaderT AtomSubstVal - (ReaderT1 LinRegions (BuilderM SimpIR)) a - --- TODO: it might make sense to replace substNonlin/isLin --- with a single `trySubtNonlin :: e i -> Maybe (e o)`. --- But for that we need a way to traverse names, like a monadic --- version of `substE`. -substNonlin :: (SinkableE e, RenameE e, HasCallStack) => e i -> TransposeM i o (e o) +substNonlin :: (PrettyE e, SinkableE e, RenameE e, HasCallStack) => e i -> TransposeM i o (e o) substNonlin e = do subst <- getSubst fmapRenamingM (\v -> case subst ! v of RenameNonlin v' -> v' - _ -> error "not a nonlinear expression") e - --- TODO: Can we generalize onNonLin to accept SubstReaderT Name instead of --- SubstReaderT AtomSubstVal? For that to work, we need another combinator, --- that lifts a SubstReader AtomSubstVal into a SubstReader Name, because --- effectsSubstE is currently typed as SubstReader AtomSubstVal. --- Then we can presumably recode substNonlin as `onNonLin substM`. We may --- be able to do that anyway, except we will then need to restrict the type --- of substNonlin to require `SubstE AtomSubstVal e`; but that may be fine. -onNonLin :: HasCallStack - => TransposeM' i o a -> TransposeM i o a -onNonLin cont = do - subst <- getSubst - let subst' = newSubst (\v -> case subst ! v of - RenameNonlin v' -> Rename v' - _ -> error "not a nonlinear expression") - liftSubstReaderT $ runSubstReaderT subst' cont - -isLin :: HoistableE e => e i -> TransposeM i o Bool -isLin e = do - substVals <- mapM lookupSubstM $ freeAtomVarsList @SimpIR e - return $ flip any substVals \case - LinTrivial -> True - LinRef _ -> True - RenameNonlin _ -> False + _ -> error $ "not a nonlinear expression: " ++ pprint e) e withAccumulator :: Emits o @@ -135,7 +96,7 @@ withAccumulator ty cont = do Nothing -> do baseMonoid <- tangentBaseMonoidFor ty getSnd =<< emitRunWriter noHint ty baseMonoid \_ ref -> - cont (LinRef $ Var ref) >> return UnitVal + cont (LinRef $ toAtom ref) >> return UnitVal Just val -> do -- If the accumulator's type is inhabited by just one value, we -- don't need any actual accumulation, and can just return that @@ -147,48 +108,44 @@ withAccumulator ty cont = do emitCTToRef :: (Emits n, Builder SimpIR m) => SAtom n -> SAtom n -> m n () emitCTToRef ref ct = do baseMonoid <- tangentBaseMonoidFor (getType ct) - void $ emitOp $ RefOp ref $ MExtend baseMonoid ct - -getLinRegions :: TransposeM i o [SAtomVar o] -getLinRegions = asks fromListE - -extendLinRegions :: SAtomVar o -> TransposeM i o a -> TransposeM i o a -extendLinRegions v cont = local (\(ListE vs) -> ListE (v:vs)) cont + void $ emitLin $ RefOp ref $ MExtend baseMonoid ct -- === actual pass === -transposeBlock :: Emits o => SBlock i -> SAtom o -> TransposeM i o () -transposeBlock (Abs decls result) ct = transposeWithDecls decls result ct - -transposeWithDecls :: Emits o => Nest SDecl i i' -> SAtom i' -> SAtom o -> TransposeM i o () -transposeWithDecls Empty atom ct = transposeAtom atom ct -transposeWithDecls (Nest (Let b (DeclBinding _ expr)) rest) result ct = - substExprIfNonlin expr >>= \case - Nothing -> do - ty' <- substNonlin $ getType expr - ctExpr <- withAccumulator ty' \refSubstVal -> - extendSubst (b @> refSubstVal) $ - transposeWithDecls rest result (sink ct) - transposeExpr expr ctExpr - Just nonlinExpr -> do - v <- emit nonlinExpr - extendSubst (b @> RenameNonlin (atomVarName v)) $ - transposeWithDecls rest result ct - -substExprIfNonlin :: SExpr i -> TransposeM i o (Maybe (SExpr o)) -substExprIfNonlin expr = - isLin expr >>= \case - True -> return Nothing - False -> do - onNonLin (substM $ getEffects expr) >>= isLinEff >>= \case - True -> return Nothing - False -> Just <$> substNonlin expr +transposeWithDecls :: forall i i' o. Emits o => Nest SDecl i i' -> SExpr i' -> SAtom o -> TransposeM i o () +transposeWithDecls Empty atom ct = transposeExpr atom ct +transposeWithDecls (Nest (Let b (DeclBinding ann expr)) rest) result ct = case ann of + LinearLet -> do + ty' <- substNonlin $ getType expr + case expr of + Project _ i x -> do + continue =<< projectLinearRef x \ref -> emitLin =<< mkProjRef ref (ProjectProduct i) + TabApp _ x i -> do + continue =<< projectLinearRef x \ref -> do + i' <- substNonlin i + emitLin =<< mkIndexRef ref i' + _ -> do + ctExpr <- withAccumulator ty' \refSubstVal -> continue refSubstVal + transposeExpr expr ctExpr + _ -> do + v <- substNonlin expr >>= emitToVar + continue $ RenameNonlin (atomVarName v) + where + continue :: forall o'. (Emits o', Ext o o') => AtomTransposeSubstVal o' -> TransposeM i o' () + continue substVal = do + ct' <- sinkM ct + extendSubst (b @> substVal) $ transposeWithDecls rest result ct' -isLinEff :: EffectRow SimpIR o -> TransposeM i o Bool -isLinEff effs@(EffectRow _ NoTail) = do - regions <- fmap atomVarName <$> getLinRegions - let effRegions = freeAtomVarsList effs - return $ not $ null $ S.fromList effRegions `S.intersection` S.fromList regions +projectLinearRef + :: Emits o + => SAtom i -> (SAtom o -> TransposeM i o (SAtom o)) + -> TransposeM i o (AtomTransposeSubstVal o) +projectLinearRef x f = do + Stuck _ (Var v) <- return x + lookupSubstM (atomVarName v) >>= \case + RenameNonlin _ -> error "nonlinear" + LinRef ref -> LinRef <$> f ref + LinTrivial -> return LinTrivial getTransposedTopFun :: EnvReader m => TopFunName n -> m n (Maybe (TopFunName n)) getTransposedTopFun f = do @@ -197,6 +154,7 @@ getTransposedTopFun f = do transposeExpr :: Emits o => SExpr i -> SAtom o -> TransposeM i o () transposeExpr expr ct = case expr of + Block _ (Abs decls result) -> transposeWithDecls decls result ct Atom atom -> transposeAtom atom ct TopApp _ f xs -> do Just fT <- getTransposedTopFun =<< substNonlin f @@ -204,54 +162,30 @@ transposeExpr expr ct = case expr of xsNonlin' <- mapM substNonlin xsNonlin ct' <- naryTopApp fT (xsNonlin' ++ [ct]) transposeAtom xLin ct' - -- TODO: Instead, should we handle table application like nonlinear - -- expressions, where we just project the reference? - TabApp _ x is -> do - is' <- mapM substNonlin is - case x of - Var v -> do - lookupSubstM (atomVarName v) >>= \case - RenameNonlin _ -> error "shouldn't happen" - LinRef ref -> do - refProj <- naryIndexRef ref (toList is') - emitCTToRef refProj ct - LinTrivial -> return () - ProjectElt _ i' x' -> do - let (idxs, v) = asNaryProj i' x' - lookupSubstM (atomVarName v) >>= \case - RenameNonlin _ -> error "an error, probably" - LinRef ref -> do - ref' <- getNaryProjRef (toList idxs) ref - refProj <- naryIndexRef ref' (toList is') - emitCTToRef refProj ct - LinTrivial -> return () - _ -> error $ "shouldn't occur: " ++ pprint x PrimOp op -> transposeOp op ct Case e alts _ -> do - linearScrutinee <- isLin e - case linearScrutinee of - True -> notImplemented - False -> do - e' <- substNonlin e - void $ buildCase e' UnitTy \i v -> do - v' <- emit (Atom v) - Abs b body <- return $ alts !! i - extendSubst (b @> RenameNonlin (atomVarName v')) do - transposeBlock body (sink ct) - return UnitVal + e' <- substNonlin e + void $ buildCase e' UnitTy \i v -> do + v' <- emitToVar v + Abs b body <- return $ alts !! i + extendSubst (b @> RenameNonlin (atomVarName v')) do + transposeExpr body (sink ct) + return UnitVal TabCon _ ty es -> do TabTy d b _ <- return ty idxTy <- substNonlin $ IxType (binderType b) d forM_ (enumerate es) \(ordinalIdx, e) -> do i <- unsafeFromOrdinal idxTy (IdxRepVal $ fromIntegral ordinalIdx) tabApp ct i >>= transposeAtom e + TabApp _ _ _ -> error "should have been handled by reference projection" + Project _ _ _ -> error "should have been handled by reference projection" transposeOp :: Emits o => PrimOp SimpIR i -> SAtom o -> TransposeM i o () transposeOp op ct = case op of DAMOp _ -> error "unreachable" -- TODO: rule out statically RefOp refArg m -> do refArg' <- substNonlin refArg - let emitEff = emitOp . RefOp refArg' + let emitEff = emitLin . RefOp refArg' case m of MAsk -> do baseMonoid <- tangentBaseMonoidFor (getType ct) @@ -269,18 +203,21 @@ transposeOp op ct = case op of ProjRef _ _ -> notImplemented Hof (TypedHof _ hof) -> transposeHof hof ct MiscOp miscOp -> transposeMiscOp miscOp ct - UnOp FNeg x -> transposeAtom x =<< neg ct + UnOp FNeg x -> transposeAtom x =<< (emitLin $ UnOp FNeg ct) UnOp _ _ -> notLinear BinOp FAdd x y -> transposeAtom x ct >> transposeAtom y ct - BinOp FSub x y -> transposeAtom x ct >> (transposeAtom y =<< neg ct) + BinOp FSub x y -> transposeAtom x ct >> (transposeAtom y =<< (emitLin $ UnOp FNeg ct)) + -- XXX: linear argument to FMul is always first BinOp FMul x y -> do - xLin <- isLin x - if xLin - then transposeAtom x =<< mul ct =<< substNonlin y - else transposeAtom y =<< mul ct =<< substNonlin x - BinOp FDiv x y -> transposeAtom x =<< div' ct =<< substNonlin y + y' <- substNonlin y + tx <- emitLin $ BinOp FMul ct y' + transposeAtom x tx + BinOp FDiv x y -> do + y' <- substNonlin y + tx <- emitLin $ BinOp FDiv ct y' + transposeAtom x tx BinOp _ _ _ -> notLinear - MemOp _ -> notLinear + MemOp _ -> notLinear VectorOp _ -> unreachable where notLinear = error $ "Can't transpose a non-linear operation: " ++ pprint op @@ -298,32 +235,23 @@ transposeMiscOp op _ = case op of BitcastOp _ _ -> notImplemented UnsafeCoerce _ _ -> notImplemented GarbageVal _ -> notImplemented - ShowAny _ -> error "Shouldn't have ShowAny in simplified IR" - ShowScalar _ -> error "Shouldn't have ShowScalar in simplified IR" - where - notLinear = error $ "Can't transpose a non-linear operation: " ++ show op + ShowAny _ -> notLinear + ShowScalar _ -> notLinear + where notLinear = error $ "Can't transpose a non-linear operation: " ++ show op transposeAtom :: HasCallStack => Emits o => SAtom i -> SAtom o -> TransposeM i o () transposeAtom atom ct = case atom of - Var v -> do - lookupSubstM (atomVarName v) >>= \case - RenameNonlin _ -> - -- XXX: we seem to need this case, but it feels like it should be an error! - return () - LinRef ref -> emitCTToRef ref ct - LinTrivial -> return () - Con con -> transposeCon con ct - DepPair _ _ _ -> notImplemented - PtrVar _ _ -> notTangent - ProjectElt _ i' x' -> do - let (idxs, v) = asNaryProj i' x' - lookupSubstM (atomVarName v) >>= \case - RenameNonlin _ -> error "an error, probably" - LinRef ref -> do - ref' <- getNaryProjRef (toList idxs) ref - emitCTToRef ref' ct - LinTrivial -> return () - RepValAtom _ -> error "not implemented" + Con con -> transposeCon con ct + Stuck _ stuck -> case stuck of + PtrVar _ _ -> notTangent + Var v -> do + lookupSubstM (atomVarName v) >>= \case + RenameNonlin _ -> error "nonlinear" + LinRef ref -> emitCTToRef ref ct + LinTrivial -> return () + StuckProject _ _ -> error "not linear" + StuckTabApp _ _ -> error "not linear" + RepValAtom _ -> error "not linear" where notTangent = error $ "Not a tangent atom: " ++ pprint atom transposeHof :: Emits o => Hof SimpIR i -> SAtom o -> TransposeM i o () @@ -331,16 +259,15 @@ transposeHof hof ct = case hof of For ann ixTy' lam -> do UnaryLamExpr b body <- return lam ixTy <- substNonlin ixTy' - void $ buildForAnn (getNameHint b) (flipDir ann) ixTy \i -> do - ctElt <- tabApp (sink ct) (Var i) - extendSubst (b@>RenameNonlin (atomVarName i)) $ transposeBlock body ctElt + void $ emitLin =<< mkFor (getNameHint b) (flipDir ann) ixTy \i -> do + ctElt <- tabApp (sink ct) (toAtom i) + extendSubst (b@>RenameNonlin (atomVarName i)) $ transposeExpr body ctElt return UnitVal RunState Nothing s (BinaryLamExpr hB refB body) -> do (ctBody, ctState) <- fromPair ct (_, cts) <- (fromPair =<<) $ emitRunState noHint ctState \h ref -> do extendSubst (hB@>RenameNonlin (atomVarName h)) $ extendSubst (refB@>RenameNonlin (atomVarName ref)) $ - extendLinRegions h $ - transposeBlock body (sink ctBody) + transposeExpr body (sink ctBody) return UnitVal transposeAtom s cts RunReader r (BinaryLamExpr hB refB body) -> do @@ -348,8 +275,7 @@ transposeHof hof ct = case hof of baseMonoid <- tangentBaseMonoidFor accumTy (_, ct') <- (fromPair =<<) $ emitRunWriter noHint accumTy baseMonoid \h ref -> do extendSubst (hB@>RenameNonlin (atomVarName h)) $ extendSubst (refB@>RenameNonlin (atomVarName ref)) $ - extendLinRegions h $ - transposeBlock body (sink ct) + transposeExpr body (sink ct) return UnitVal transposeAtom r ct' RunWriter Nothing _ (BinaryLamExpr hB refB body)-> do @@ -357,8 +283,7 @@ transposeHof hof ct = case hof of (ctBody, ctEff) <- fromPair ct void $ emitRunReader noHint ctEff \h ref -> do extendSubst (hB@>RenameNonlin (atomVarName h)) $ extendSubst (refB@>RenameNonlin (atomVarName ref)) $ - extendLinRegions h $ - transposeBlock body (sink ctBody) + transposeExpr body (sink ctBody) return UnitVal _ -> notImplemented @@ -366,11 +291,10 @@ transposeCon :: Emits o => Con SimpIR i -> SAtom o -> TransposeM i o () transposeCon con ct = case con of Lit _ -> return () ProdCon [] -> return () - ProdCon xs -> - forM_ (enumerate xs) \(i, x) -> - projectTuple i ct >>= transposeAtom x + ProdCon xs -> forM_ (enumerate xs) \(i, x) -> proj i ct >>= transposeAtom x SumCon _ _ _ -> notImplemented HeapVal -> notTangent + DepPair _ _ _ -> notImplemented where notTangent = error $ "Not a tangent atom: " ++ pprint (Con con) notImplemented :: HasCallStack => a diff --git a/src/lib/TraverseSourceInfo.hs b/src/lib/TraverseSourceInfo.hs deleted file mode 100644 index 1265fbf51..000000000 --- a/src/lib/TraverseSourceInfo.hs +++ /dev/null @@ -1,127 +0,0 @@ --- Copyright 2022 Google LLC --- --- Use of this source code is governed by a BSD-style --- license that can be found in the LICENSE file or at --- https://developers.google.com/open-source/licenses/bsd - -{-# OPTIONS_GHC -Wno-incomplete-patterns #-} - -module TraverseSourceInfo (HasSourceInfo, gatherSourceInfo, addSpanIds) where - -import qualified Data.ByteString as BS -import Control.Monad.State -import Control.Monad.Writer -import GHC.Generics -import GHC.Int -import GHC.Word - -import Occurrence qualified as Occ -import SourceInfo -import Types.OpNames qualified as P -import Types.Primitives -import Types.Source - -class HasSourceInfo a where - traverseSourceInfo :: Applicative m => (SrcPosCtx -> m SrcPosCtx) -> a -> m a - - default traverseSourceInfo :: (Applicative m, Generic a, HasSourceInfo (Rep a Any)) => (SrcPosCtx -> m SrcPosCtx) -> a -> m a - traverseSourceInfo f x = to <$> traverseSourceInfo f (from x :: Rep a Any) - -tc :: HasSourceInfo a => Applicative m => (SrcPosCtx -> m SrcPosCtx) -> a -> m a -tc = traverseSourceInfo - -instance HasSourceInfo (V1 p) where - traverseSourceInfo _ x = pure x - -instance HasSourceInfo (U1 p) where - traverseSourceInfo _ x = pure x - -instance (HasSourceInfo c) => HasSourceInfo (K1 i c p) where - traverseSourceInfo f (K1 x) = K1 <$> traverseSourceInfo f x - -instance HasSourceInfo (f p) => HasSourceInfo (M1 i c f p) where - traverseSourceInfo f (M1 x) = M1 <$> traverseSourceInfo f x - -instance (HasSourceInfo (a p), HasSourceInfo (b p)) => HasSourceInfo ((a :+: b) p) where - traverseSourceInfo f (L1 x) = L1 <$> traverseSourceInfo f x - traverseSourceInfo f (R1 x) = R1 <$> traverseSourceInfo f x - -instance (HasSourceInfo (a p), HasSourceInfo (b p)) => HasSourceInfo ((a :*: b) p) where - traverseSourceInfo f (a :*: b) = (:*:) <$> traverseSourceInfo f a <*> traverseSourceInfo f b - -instance HasSourceInfo P.TC -instance HasSourceInfo P.Con -instance HasSourceInfo P.MemOp -instance HasSourceInfo P.VectorOp -instance HasSourceInfo P.MiscOp -instance HasSourceInfo PrimName -instance HasSourceInfo UnOp -instance HasSourceInfo BinOp -instance HasSourceInfo CmpOp -instance HasSourceInfo BaseType -instance HasSourceInfo ScalarBaseType -instance HasSourceInfo Device - -instance (HasSourceInfo a, HasSourceInfo b) => HasSourceInfo (a, b) -instance (HasSourceInfo a, HasSourceInfo b, HasSourceInfo c) => HasSourceInfo (a, b, c) -instance (HasSourceInfo a, HasSourceInfo b) => HasSourceInfo (Either a b) -instance HasSourceInfo a => HasSourceInfo [a] -instance HasSourceInfo a => HasSourceInfo (Maybe a) - -instance HasSourceInfo Occ.Count -instance HasSourceInfo Occ.UsageInfo -instance HasSourceInfo LetAnn -instance HasSourceInfo UResumePolicy -instance HasSourceInfo CInstanceDef -instance HasSourceInfo CTopDecl' - -instance HasSourceInfo AppExplicitness -instance HasSourceInfo CDef -instance HasSourceInfo CSDecl' -instance HasSourceInfo CSBlock -instance HasSourceInfo ForKind -instance HasSourceInfo Group' - -instance HasSourceInfo Bin' - -instance HasSourceInfo a => HasSourceInfo (WithSrc a) where - traverseSourceInfo f (WithSrc pos x) = WithSrc <$> f pos <*> tc f x - -instance HasSourceInfo () where - traverseSourceInfo _ x = pure x -instance HasSourceInfo Char where - traverseSourceInfo _ x = pure x -instance HasSourceInfo Int where - traverseSourceInfo _ x = pure x -instance HasSourceInfo Int32 where - traverseSourceInfo _ x = pure x -instance HasSourceInfo Int64 where - traverseSourceInfo _ x = pure x -instance HasSourceInfo Word8 where - traverseSourceInfo _ x = pure x -instance HasSourceInfo Word16 where - traverseSourceInfo _ x = pure x -instance HasSourceInfo Word32 where - traverseSourceInfo _ x = pure x -instance HasSourceInfo Word64 where - traverseSourceInfo _ x = pure x -instance HasSourceInfo Float where - traverseSourceInfo _ x = pure x -instance HasSourceInfo Double where - traverseSourceInfo _ x = pure x -instance HasSourceInfo BS.ByteString where - traverseSourceInfo _ x = pure x - --- The real base case. -instance HasSourceInfo SrcPosCtx where - traverseSourceInfo f x = f x - -gatherSourceInfo :: (HasSourceInfo a) => a -> [SrcPosCtx] -gatherSourceInfo x = execWriter (tc (\(ctx :: SrcPosCtx) -> tell [ctx] >> return ctx) x) - -addSpanIds :: (HasSourceInfo a) => a -> a -addSpanIds x = evalState (tc f x) 0 - where f (SrcPosCtx maybeSrcPos _) = do - currentId <- get - put (currentId + 1) - return (SrcPosCtx maybeSrcPos (Just currentId)) diff --git a/src/lib/Types/Core.hs b/src/lib/Types/Core.hs index 067af737e..daee75118 100644 --- a/src/lib/Types/Core.hs +++ b/src/lib/Types/Core.hs @@ -4,20 +4,8 @@ -- license that can be found in the LICENSE file or at -- https://developers.google.com/open-source/licenses/bsd -{-# LANGUAGE DeriveGeneric #-} -{-# LANGUAGE FlexibleContexts #-} -{-# LANGUAGE FlexibleInstances #-} -{-# LANGUAGE StrictData #-} -{-# LANGUAGE DeriveFunctor #-} -{-# LANGUAGE ViewPatterns #-} -{-# LANGUAGE TypeFamilies #-} -{-# LANGUAGE InstanceSigs #-} {-# LANGUAGE UndecidableInstances #-} -{-# LANGUAGE StandaloneDeriving #-} -{-# LANGUAGE DerivingStrategies #-} -{-# LANGUAGE DerivingVia #-} -{-# LANGUAGE DefaultSignatures #-} -{-# LANGUAGE ConstraintKinds #-} +{-# LANGUAGE StrictData #-} -- Core data types for CoreIR and its variations. @@ -25,20 +13,20 @@ module Types.Core (module Types.Core, SymbolicZeros (..)) where import Data.Word import Data.Maybe (fromJust) -import Data.Functor +import Data.Foldable (toList) import Data.Hashable -import Data.Text.Prettyprint.Doc hiding (nest) +import Data.String (fromString) +import Data.Text.Prettyprint.Doc +import Data.Text (Text, unsnoc, uncons) import qualified Data.Map.Strict as M -import qualified Data.Set as S import GHC.Generics (Generic (..)) import Data.Store (Store (..)) -import Foreign.Ptr import Name -import Util (FileHash, SnocList (..), Tree (..)) +import Util (Tree (..)) import IRVariants -import SourceInfo +import PPrint import qualified Types.OpNames as P import Types.Primitives @@ -48,63 +36,91 @@ import Types.Imp -- === core IR === data Atom (r::IR) (n::S) where - Var :: AtomVar r n -> Atom r n - Con :: Con r n -> Atom r n - PtrVar :: PtrType -> PtrName n -> Atom r n - ProjectElt :: Type r n -> Projection -> Atom r n -> Atom r n - DepPair :: Atom r n -> Atom r n -> DepPairType r n -> Atom r n - -- === CoreIR only === - Lam :: CoreLamExpr n -> Atom CoreIR n - Eff :: EffectRow CoreIR n -> Atom CoreIR n - DictCon :: Type CoreIR n -> DictExpr n -> Atom CoreIR n - NewtypeCon :: NewtypeCon n -> Atom CoreIR n -> Atom CoreIR n - DictHole :: AlwaysEqual SrcPosCtx -> Type CoreIR n -> RequiredMethodAccess - -> Atom CoreIR n - TypeAsAtom :: Type CoreIR n -> Atom CoreIR n - -- === Shims between IRs === - SimpInCore :: SimpInCore n -> Atom CoreIR n - RepValAtom :: RepVal SimpIR n -> Atom SimpIR n + Con :: Con r n -> Atom r n + Stuck :: Type r n -> Stuck r n -> Atom r n + deriving (Show, Generic) data Type (r::IR) (n::S) where - TC :: TC r n -> Type r n - TabPi :: TabPiType r n -> Type r n - DepPairTy :: DepPairType r n -> Type r n - TyVar :: AtomVar CoreIR n -> Type CoreIR n - DictTy :: DictType n -> Type CoreIR n - Pi :: CorePiType n -> Type CoreIR n - NewtypeTyCon :: NewtypeTyCon n -> Type CoreIR n - -- It was bad enough having this in `Atom`, but it's even worse now that it's - -- replicated in `Type` too. We should be able to remove both once - -- we represent types as normalized blocks. - ProjectEltTy :: CType n -> Projection -> CAtom n -> Type CoreIR n + TyCon :: TyCon r n -> Type r n + StuckTy :: CType n -> CStuck n -> Type CoreIR n + +data Dict (r::IR) (n::S) where + DictCon :: DictCon r n -> Dict r n + StuckDict :: CType n -> CStuck n -> Dict CoreIR n + +data Con (r::IR) (n::S) where + Lit :: LitVal -> Con r n + ProdCon :: [Atom r n] -> Con r n + SumCon :: [Type r n] -> Int -> Atom r n -> Con r n -- type, tag, payload + HeapVal :: Con r n + DepPair :: Atom r n -> Atom r n -> DepPairType r n -> Con r n + Lam :: CoreLamExpr n -> Con CoreIR n + Eff :: EffectRow CoreIR n -> Con CoreIR n + NewtypeCon :: NewtypeCon n -> Atom CoreIR n -> Con CoreIR n + DictConAtom :: DictCon CoreIR n -> Con CoreIR n + TyConAtom :: TyCon CoreIR n -> Con CoreIR n + +data Stuck (r::IR) (n::S) where + Var :: AtomVar r n -> Stuck r n + StuckProject :: Int -> Stuck r n -> Stuck r n + StuckTabApp :: Stuck r n -> Atom r n -> Stuck r n + PtrVar :: PtrType -> PtrName n -> Stuck r n + RepValAtom :: RepVal n -> Stuck SimpIR n + StuckUnwrap :: CStuck n -> Stuck CoreIR n + InstantiatedGiven :: CStuck n -> [CAtom n] -> Stuck CoreIR n + SuperclassProj :: Int -> CStuck n -> Stuck CoreIR n + LiftSimp :: CType n -> Stuck SimpIR n -> Stuck CoreIR n + LiftSimpFun :: CorePiType n -> LamExpr SimpIR n -> Stuck CoreIR n + -- TabLam and ACase are just defunctionalization tools. The result type + -- in both cases should *not* be `Data`. + TabLam :: TabLamExpr n -> Stuck CoreIR n + ACase :: SStuck n -> [Abs SBinder CAtom n] -> CType n -> Stuck CoreIR n + +data TyCon (r::IR) (n::S) where + BaseType :: BaseType -> TyCon r n + ProdType :: [Type r n] -> TyCon r n + SumType :: [Type r n] -> TyCon r n + RefType :: Atom r n -> Type r n -> TyCon r n + HeapType :: TyCon r n + TabPi :: TabPiType r n -> TyCon r n + DepPairTy :: DepPairType r n -> TyCon r n + TypeKind :: TyCon CoreIR n + DictTy :: DictType n -> TyCon CoreIR n + Pi :: CorePiType n -> TyCon CoreIR n + NewtypeTyCon :: NewtypeTyCon n -> TyCon CoreIR n data AtomVar (r::IR) (n::S) = AtomVar { atomVarName :: AtomName r n , atomVarType :: Type r n } deriving (Show, Generic) -type TabLamExpr = PairE (IxType SimpIR) (Abs (Binder SimpIR) (Abs (Nest SDecl) CAtom)) -data SimpInCore (n::S) = - LiftSimp (CType n) (SAtom n) - | LiftSimpFun (CorePiType n) (LamExpr SimpIR n) - | TabLam (TabPiType CoreIR n) (TabLamExpr n) - | ACase (SAtom n) [Abs SBinder CAtom n] (CType n) - deriving (Show, Generic) - -deriving instance IRRep r => Show (Atom r n) -deriving instance IRRep r => Show (Type r n) -deriving via WrapE (Atom r) n instance IRRep r => Generic (Atom r n) -deriving via WrapE (Type r) n instance IRRep r => Generic (Type r n) - +deriving instance IRRep r => Show (DictCon r n) +deriving instance IRRep r => Show (Dict r n) +deriving instance IRRep r => Show (Con r n) +deriving instance IRRep r => Show (TyCon r n) +deriving instance IRRep r => Show (Type r n) +deriving instance IRRep r => Show (Stuck r n) + +deriving via WrapE (DictCon r) n instance IRRep r => Generic (DictCon r n) +deriving via WrapE (Dict r) n instance IRRep r => Generic (Dict r n) +deriving via WrapE (Con r) n instance IRRep r => Generic (Con r n) +deriving via WrapE (TyCon r) n instance IRRep r => Generic (TyCon r n) +deriving via WrapE (Type r) n instance IRRep r => Generic (Type r n) +deriving via WrapE (Stuck r) n instance IRRep r => Generic (Stuck r n) + +-- TODO: factor out the EffTy and maybe merge with PrimOp data Expr r n where + Block :: EffTy r n -> Block r n -> Expr r n TopApp :: EffTy SimpIR n -> TopFunName n -> [SAtom n] -> Expr SimpIR n - TabApp :: Type r n -> Atom r n -> [Atom r n] -> Expr r n - Case :: Atom r n -> [Alt r n] -> EffTy r n -> Expr r n - Atom :: Atom r n -> Expr r n - TabCon :: Maybe (WhenCore r Dict n) -> Type r n -> [Atom r n] -> Expr r n - PrimOp :: PrimOp r n -> Expr r n - App :: EffTy CoreIR n -> CAtom n -> [CAtom n] -> Expr CoreIR n - ApplyMethod :: EffTy CoreIR n -> CAtom n -> Int -> [CAtom n] -> Expr CoreIR n + TabApp :: Type r n -> Atom r n -> Atom r n -> Expr r n + Case :: Atom r n -> [Alt r n] -> EffTy r n -> Expr r n + Atom :: Atom r n -> Expr r n + TabCon :: Maybe (WhenCore r (Dict CoreIR) n) -> Type r n -> [Atom r n] -> Expr r n + PrimOp :: PrimOp r n -> Expr r n + Project :: Type r n -> Int -> Atom r n -> Expr r n + App :: EffTy CoreIR n -> CAtom n -> [CAtom n] -> Expr CoreIR n + Unwrap :: CType n -> CAtom n -> Expr CoreIR n + ApplyMethod :: EffTy CoreIR n -> CAtom n -> Int -> [CAtom n] -> Expr CoreIR n deriving instance IRRep r => Show (Expr r n) deriving via WrapE (Expr r) n instance IRRep r => Generic (Expr r n) @@ -114,6 +130,9 @@ data BaseMonoid r n = , baseCombine :: LamExpr r n } deriving (Show, Generic) +data RepVal (n::S) = RepVal (SType n) (Tree (IExpr n)) + deriving (Show, Generic) + data DeclBinding r n = DeclBinding LetAnn (Expr r n) deriving (Show, Generic) data Decl (r::IR) (n::S) (l::S) = Let (AtomNameBinder r n l) (DeclBinding r n) @@ -141,7 +160,7 @@ type FunObjCodeName = Name FunObjCodeNameC type AtomBinderP (r::IR) = BinderP (AtomNameC r) type Binder r = AtomBinderP r (Type r) :: B -type Alt r = Abs (Binder r) (Block r) :: E +type Alt r = Abs (Binder r) (Expr r) :: E newtype DotMethods n = DotMethods (M.Map SourceName (CAtomName n)) deriving (Show, Generic, Monoid, Semigroup) @@ -175,30 +194,15 @@ data TyConParams n = TyConParams [Explicitness] [Atom CoreIR n] deriving (Show, Generic) type WithDecls (r::IR) = Abs (Decls r) :: E -> E -type Block (r::IR) = WithDecls r (Atom r) :: E - -type TopBlock = TopLam -- used for nullary lambda -type IsDestLam = Bool -data TopLam (r::IR) (n::S) = TopLam IsDestLam (PiType r n) (LamExpr r n) - deriving (Show, Generic) +type Block (r::IR) = WithDecls r (Expr r) :: E data LamExpr (r::IR) (n::S) where - LamExpr :: Nest (Binder r) n l -> Block r l -> LamExpr r n - -data CoreLamExpr (n::S) = CoreLamExpr (CorePiType n) (LamExpr CoreIR n) + LamExpr :: Nest (Binder r) n l -> Expr r l -> LamExpr r n -data IxDict r n where - IxDictAtom :: Atom CoreIR n -> IxDict CoreIR n - -- TODO: make these two only available in SimpIR (currently we can't do that - -- because we need CoreIR to be a superset of SimpIR) - -- IxDictRawFin is used post-simplification. It behaves like `Fin`, but - -- it's parameterized by a newtyped-stripped `IxRepVal` instead of `Nat`, and - -- it describes indices of type `IxRepVal`. - IxDictRawFin :: Atom r n -> IxDict r n - IxDictSpecialized :: SType n -> SpecDictName n -> [SAtom n] -> IxDict SimpIR n +data CoreLamExpr (n::S) = CoreLamExpr (CorePiType n) (LamExpr CoreIR n) deriving (Show, Generic) -deriving instance IRRep r => Show (IxDict r n) -deriving via WrapE (IxDict r) n instance IRRep r => Generic (IxDict r n) +type TabLamExpr = PairE (TabPiType CoreIR) (Abs SBinder CAtom) +type IxDict = Dict data IxMethod = Size | Ordinal | UnsafeFromOrdinal deriving (Show, Generic, Enum, Bounded, Eq) @@ -222,7 +226,6 @@ data DepPairType (r::IR) (n::S) where type Val = Atom type Kind = Type -type Dict = Atom CoreIR -- A nest where the annotation of a binder cannot depend on the binders -- introduced before it. You can think of it as introducing a bunch of @@ -239,7 +242,7 @@ class ToBindersAbs (e::E) (body::E) (r::IR) | e -> body, e -> r where instance ToBindersAbs CorePiType (EffTy CoreIR) CoreIR where toAbs (CorePiType _ _ bs effTy) = Abs bs effTy -instance ToBindersAbs CoreLamExpr (Block CoreIR) CoreIR where +instance ToBindersAbs CoreLamExpr (Expr CoreIR) CoreIR where toAbs (CoreLamExpr _ lam) = toAbs lam instance ToBindersAbs (Abs (Nest (Binder r)) body) body r where @@ -248,7 +251,7 @@ instance ToBindersAbs (Abs (Nest (Binder r)) body) body r where instance ToBindersAbs (PiType r) (EffTy r) r where toAbs (PiType bs effTy) = Abs bs effTy -instance ToBindersAbs (LamExpr r) (Block r) r where +instance ToBindersAbs (LamExpr r) (Expr r) r where toAbs (LamExpr bs body) = Abs bs body instance ToBindersAbs (TabPiType r) (Type r) r where @@ -264,19 +267,10 @@ instance ToBindersAbs TyConDef DataConDefs CoreIR where toAbs (TyConDef _ _ bs body) = Abs bs body instance ToBindersAbs ClassDef (Abs (Nest CBinder) (ListE CorePiType)) CoreIR where - toAbs (ClassDef _ _ _ _ bs scBs tys) = Abs bs (Abs scBs (ListE tys)) - -instance ToBindersAbs (TopLam r) (Block r) r where - toAbs (TopLam _ _ lam) = toAbs lam + toAbs (ClassDef _ _ _ _ _ bs scBs tys) = Abs bs (Abs scBs (ListE tys)) -- === GenericOp class === -class IsPrimOp (e::IR->E) where - toPrimOp :: e r n -> PrimOp r n - -instance IsPrimOp PrimOp where - toPrimOp x = x - class GenericOp (e::IR->E) where type OpConst e (r::IR) :: * fromOp :: e r n -> GenericOpRep (OpConst e r) r n @@ -323,22 +317,6 @@ traverseOp op fType fAtom fLam = do -- === Various ops === -data TC (r::IR) (n::S) where - BaseType :: BaseType -> TC r n - ProdType :: [Type r n] -> TC r n - SumType :: [Type r n] -> TC r n - RefType :: Atom r n -> Type r n -> TC r n - TypeKind :: TC r n -- TODO: `HasCore r` constraint - HeapType :: TC r n - deriving (Show, Generic) - -data Con (r::IR) (n::S) where - Lit :: LitVal -> Con r n - ProdCon :: [Atom r n] -> Con r n - SumCon :: [Type r n] -> Int -> Atom r n -> Con r n -- type, tag, payload - HeapVal :: Con r n - deriving (Show, Generic) - data PrimOp (r::IR) (n::S) where UnOp :: P.UnOp -> Atom r n -> PrimOp r n BinOp :: P.BinOp -> Atom r n -> Atom r n -> PrimOp r n @@ -396,13 +374,13 @@ data TypedHof r n = TypedHof (EffTy r n) (Hof r n) data Hof r n where For :: ForAnn -> IxType r n -> LamExpr r n -> Hof r n - While :: Block r n -> Hof r n + While :: Expr r n -> Hof r n RunReader :: Atom r n -> LamExpr r n -> Hof r n RunWriter :: Maybe (Atom r n) -> BaseMonoid r n -> LamExpr r n -> Hof r n RunState :: Maybe (Atom r n) -> Atom r n -> LamExpr r n -> Hof r n -- dest, initial value, body lambda - RunIO :: Block r n -> Hof r n - RunInit :: Block r n -> Hof r n - CatchException :: CType n -> Block CoreIR n -> Hof CoreIR n + RunIO :: Expr r n -> Hof r n + RunInit :: Expr r n -> Hof r n + CatchException :: CType n -> Expr CoreIR n -> Hof CoreIR n Linearize :: LamExpr CoreIR n -> Atom CoreIR n -> Hof CoreIR n Transpose :: LamExpr CoreIR n -> Atom CoreIR n -> Hof CoreIR n @@ -431,6 +409,8 @@ data RefOp r n = type CAtom = Atom CoreIR type CType = Type CoreIR +type CDict = Dict CoreIR +type CStuck = Stuck CoreIR type CBinder = Binder CoreIR type CExpr = Expr CoreIR type CBlock = Block CoreIR @@ -438,10 +418,11 @@ type CDecl = Decl CoreIR type CDecls = Decls CoreIR type CAtomName = AtomName CoreIR type CAtomVar = AtomVar CoreIR -type CTopLam = TopLam CoreIR type SAtom = Atom SimpIR type SType = Type SimpIR +type SDict = Dict SimpIR +type SStuck = Stuck SimpIR type SExpr = Expr SimpIR type SBlock = Block SimpIR type SAlt = Alt SimpIR @@ -450,9 +431,7 @@ type SDecls = Decls SimpIR type SAtomName = AtomName SimpIR type SAtomVar = AtomVar SimpIR type SBinder = Binder SimpIR -type SRepVal = RepVal SimpIR type SLam = LamExpr SimpIR -type STopLam = TopLam SimpIR -- === newtypes === @@ -470,9 +449,6 @@ data NewtypeTyCon (n::S) = | UserADTType SourceName (TyConName n) (TyConParams n) deriving (Show, Generic) -pattern TypeCon :: SourceName -> TyConName n -> TyConParams n -> CType n -pattern TypeCon s d xs = NewtypeTyCon (UserADTType s d xs) - isSumCon :: NewtypeCon n -> Bool isSumCon = \case UserADTData _ _ _ -> True @@ -485,6 +461,7 @@ type RoleExpl = (ParamRole, Explicitness) data ClassDef (n::S) where ClassDef :: SourceName -- name of class + -> Maybe BuiltinClassName -> [SourceName] -- method source names -> [Maybe SourceName] -- parameter source names -> [RoleExpl] -- parameter info @@ -493,6 +470,8 @@ data ClassDef (n::S) where -> [CorePiType n3] -- method types -> ClassDef n1 +data BuiltinClassName = Data | Ix deriving (Show, Generic, Eq) + data InstanceDef (n::S) where InstanceDef :: ClassName n1 @@ -508,189 +487,23 @@ data InstanceBody (n::S) = [CAtom n] -- method definitions deriving (Show, Generic) -data DictType (n::S) = DictType SourceName (ClassName n) [CAtom n] - deriving (Show, Generic) - -data DictExpr (n::S) = - InstantiatedGiven (CAtom n) [CAtom n] - | SuperclassProj (CAtom n) Int -- (could instantiate here too, but we don't need it for now) - -- We use NonEmpty because givens without args can be represented using `Var`. - | InstanceDict (InstanceName n) [CAtom n] - -- Special case for `Ix (Fin n)` (TODO: a more general mechanism for built-in classes and instances) - | IxFin (CAtom n) - -- Special case for `Data ` - | DataData (CType n) +data DictType (n::S) = + DictType SourceName (ClassName n) [CAtom n] + | IxDictType (CType n) + | DataDictType (CType n) deriving (Show, Generic) --- TODO: Use an IntMap -newtype CustomRules (n::S) = - CustomRules { customRulesMap :: M.Map (AtomName CoreIR n) (AtomRules n) } - deriving (Semigroup, Monoid, Store) -data AtomRules (n::S) = - -- number of implicit args, number of explicit args, linearization function - CustomLinearize Int Int SymbolicZeros (CAtom n) - deriving (Generic) - --- === Runtime representations === - -data RepVal (r::IR) (n::S) = RepVal (Type r n) (Tree (IExpr n)) - deriving (Show, Generic) - --- === envs and modules === - --- `ModuleEnv` contains data that only makes sense in the context of evaluating --- a particular module. `TopEnv` contains everything that makes sense "between" --- evaluating modules. -data Env n = Env - { topEnv :: {-# UNPACK #-} TopEnv n - , moduleEnv :: {-# UNPACK #-} ModuleEnv n } - deriving (Generic) - -data TopEnv (n::S) = TopEnv - { envDefs :: RecSubst Binding n - , envCustomRules :: CustomRules n - , envCache :: Cache n - , envLoadedModules :: LoadedModules n - , envLoadedObjects :: LoadedObjects n } - deriving (Generic) - -data SerializedEnv n = SerializedEnv - { serializedEnvDefs :: RecSubst Binding n - , serializedEnvCustomRules :: CustomRules n - , serializedEnvCache :: Cache n } - deriving (Generic) - --- TODO: consider splitting this further into `ModuleEnv` (the env that's --- relevant between top-level decls) and `LocalEnv` (the additional parts of the --- env that's relevant under a lambda binder). Unlike the Top/Module --- distinction, there's some overlap. For example, instances can be defined at --- both the module-level and local level. Similarly, if we start allowing --- top-level effects in `Main` then we'll have module-level effects and local --- effects. -data ModuleEnv (n::S) = ModuleEnv - { envImportStatus :: ImportStatus n - , envSourceMap :: SourceMap n - , envSynthCandidates :: SynthCandidates n } - deriving (Generic) - -data Module (n::S) = Module - { moduleSourceName :: ModuleSourceName - , moduleDirectDeps :: S.Set (ModuleName n) - , moduleTransDeps :: S.Set (ModuleName n) -- XXX: doesn't include the module itself - , moduleExports :: SourceMap n - -- these are just the synth candidates required by this - -- module by itself. We'll usually also need those required by the module's - -- (transitive) dependencies, which must be looked up separately. - , moduleSynthCandidates :: SynthCandidates n } - deriving (Show, Generic) - -data LoadedModules (n::S) = LoadedModules - { fromLoadedModules :: M.Map ModuleSourceName (ModuleName n)} - deriving (Show, Generic) +data DictCon (r::IR) (n::S) where + InstanceDict :: CType n -> InstanceName n -> [CAtom n] -> DictCon CoreIR n + -- Special case for `Data ` + DataData :: CType n -> DictCon CoreIR n + IxFin :: CAtom n -> DictCon CoreIR n + -- IxRawFin is like `Fin`, but it's parameterized by a newtyped-stripped + -- `IxRepVal` instead of `Nat`, and it describes indices of type `IxRepVal`. + -- TODO: make is SimpIR-only + IxRawFin :: Atom r n -> DictCon r n + IxSpecialized :: SpecDictName n -> [SAtom n] -> DictCon SimpIR n -emptyModuleEnv :: ModuleEnv n -emptyModuleEnv = ModuleEnv emptyImportStatus (SourceMap mempty) mempty - -emptyLoadedModules :: LoadedModules n -emptyLoadedModules = LoadedModules mempty - -data LoadedObjects (n::S) = LoadedObjects - -- the pointer points to the actual runtime function - { fromLoadedObjects :: M.Map (FunObjCodeName n) NativeFunction} - deriving (Show, Generic) - -emptyLoadedObjects :: LoadedObjects n -emptyLoadedObjects = LoadedObjects mempty - -data ImportStatus (n::S) = ImportStatus - { directImports :: S.Set (ModuleName n) - -- XXX: This are cached for efficiency. It's derivable from `directImports`. - , transImports :: S.Set (ModuleName n) } - deriving (Show, Generic) - -data TopEnvFrag n l = TopEnvFrag (EnvFrag n l) (ModuleEnv l) (SnocList (TopEnvUpdate l)) - -data TopEnvUpdate n = - ExtendCache (Cache n) - | AddCustomRule (CAtomName n) (AtomRules n) - | UpdateLoadedModules ModuleSourceName (ModuleName n) - | UpdateLoadedObjects (FunObjCodeName n) NativeFunction - | FinishDictSpecialization (SpecDictName n) [TopLam SimpIR n] - | LowerDictSpecialization (SpecDictName n) [TopLam SimpIR n] - | UpdateTopFunEvalStatus (TopFunName n) (TopFunEvalStatus n) - | UpdateInstanceDef (InstanceName n) (InstanceDef n) - | UpdateTyConDef (TyConName n) (TyConDef n) - | UpdateFieldDef (TyConName n) SourceName (CAtomName n) - --- TODO: we could add a lot more structure for querying by dict type, caching, etc. --- TODO: make these `Name n` instead of `Atom n` so they're usable as cache keys. -data SynthCandidates n = SynthCandidates - { lambdaDicts :: [AtomName CoreIR n] - , instanceDicts :: M.Map (ClassName n) [InstanceName n] } - deriving (Show, Generic) - -emptyImportStatus :: ImportStatus n -emptyImportStatus = ImportStatus mempty mempty - --- TODO: figure out the additional top-level context we need -- backend, other --- compiler flags etc. We can have a map from those to this. - -data Cache (n::S) = Cache - { specializationCache :: EMap SpecializationSpec TopFunName n - , ixDictCache :: EMap AbsDict SpecDictName n - , linearizationCache :: EMap LinearizationSpec (PairE TopFunName TopFunName) n - , transpositionCache :: EMap TopFunName TopFunName n - -- This is memoizing `parseAndGetDeps :: Text -> [ModuleSourceName]`. But we - -- only want to store one entry per module name as a simple cache eviction - -- policy, so we store it keyed on the module name, with the text hash for - -- the validity check. - , parsedDeps :: M.Map ModuleSourceName (FileHash, [ModuleSourceName]) - , moduleEvaluations :: M.Map ModuleSourceName ((FileHash, [ModuleName n]), ModuleName n) - } deriving (Show, Generic) - --- === runtime function and variable representations === - -type RuntimeEnv = DynamicVarKeyPtrs - -type DexDestructor = FunPtr (IO ()) - -data NativeFunction = NativeFunction - { nativeFunPtr :: FunPtr () - , nativeFunTeardown :: IO () } - -instance Show NativeFunction where - show _ = "" - --- Holds pointers to thread-local storage used to simulate dynamically scoped --- variables, such as the output stream file descriptor. -type DynamicVarKeyPtrs = [(DynamicVar, Ptr ())] - -data DynamicVar = OutStreamDyvar -- TODO: add others as needed - deriving (Enum, Bounded) - -dynamicVarCName :: DynamicVar -> String -dynamicVarCName OutStreamDyvar = "dex_out_stream_dyvar" - -dynamicVarLinkMap :: DynamicVarKeyPtrs -> [(String, Ptr ())] -dynamicVarLinkMap dyvars = dyvars <&> \(v, ptr) -> (dynamicVarCName v, ptr) - --- === bindings - static information we carry about a lexical scope === - --- TODO: consider making this an open union via a typeable-like class -data Binding (c::C) (n::S) where - AtomNameBinding :: AtomBinding r n -> Binding (AtomNameC r) n - TyConBinding :: Maybe (TyConDef n) -> DotMethods n -> Binding TyConNameC n - DataConBinding :: TyConName n -> Int -> Binding DataConNameC n - ClassBinding :: ClassDef n -> Binding ClassNameC n - InstanceBinding :: InstanceDef n -> CorePiType n -> Binding InstanceNameC n - MethodBinding :: ClassName n -> Int -> Binding MethodNameC n - TopFunBinding :: TopFun n -> Binding TopFunNameC n - FunObjCodeBinding :: CFunction n -> Binding FunObjCodeNameC n - ModuleBinding :: Module n -> Binding ModuleNameC n - -- TODO: add a case for abstracted pointers, as used in `ClosedImpFunction` - PtrBinding :: PtrType -> PtrLitVal -> Binding PtrNameC n - SpecializedDictBinding :: SpecializedDictDef n -> Binding SpecializedDictNameC n - ImpNameBinding :: BaseType -> Binding ImpNameC n data EffectOpDef (n::S) where EffectOpDef :: EffectName n -- name of associated effect @@ -749,111 +562,6 @@ instance RenameE EffectOpType deriving instance Show (EffectOpType n) deriving via WrapE EffectOpType n instance Generic (EffectOpType n) -instance GenericE SpecializedDictDef where - type RepE SpecializedDictDef = AbsDict `PairE` MaybeE (ListE (TopLam SimpIR)) - fromE (SpecializedDict ab methods) = ab `PairE` methods' - where methods' = case methods of Just xs -> LeftE (ListE xs) - Nothing -> RightE UnitE - {-# INLINE fromE #-} - toE (ab `PairE` methods) = SpecializedDict ab methods' - where methods' = case methods of LeftE (ListE xs) -> Just xs - RightE UnitE -> Nothing - {-# INLINE toE #-} - -instance SinkableE SpecializedDictDef -instance HoistableE SpecializedDictDef -instance AlphaEqE SpecializedDictDef -instance AlphaHashableE SpecializedDictDef -instance RenameE SpecializedDictDef - -data EvalStatus a = Waiting | Running | Finished a - deriving (Show, Eq, Ord, Generic, Functor, Foldable, Traversable) -type TopFunEvalStatus n = EvalStatus (TopFunLowerings n) - -data TopFun (n::S) = - DexTopFun (TopFunDef n) (TopLam SimpIR n) (TopFunEvalStatus n) - | FFITopFun String IFunType - deriving (Show, Generic) - -data TopFunDef (n::S) = - Specialization (SpecializationSpec n) - | LinearizationPrimal (LinearizationSpec n) - -- Tangent functions all take some number of nonlinear args, then a *single* - -- linear arg. This is so that transposition can be an involution - you apply - -- it twice and you get back to the original function. - | LinearizationTangent (LinearizationSpec n) - deriving (Show, Generic) - -newtype TopFunLowerings (n::S) = TopFunLowerings - { topFunObjCode :: FunObjCodeName n } -- TODO: add optimized, imp etc. as needed - deriving (Show, Generic, SinkableE, HoistableE, RenameE, AlphaEqE, AlphaHashableE, Pretty) - -data AtomBinding (r::IR) (n::S) where - LetBound :: DeclBinding r n -> AtomBinding r n - MiscBound :: Type r n -> AtomBinding r n - TopDataBound :: RepVal SimpIR n -> AtomBinding SimpIR n - SolverBound :: SolverBinding n -> AtomBinding CoreIR n - NoinlineFun :: CType n -> CAtom n -> AtomBinding CoreIR n - FFIFunBound :: CorePiType n -> TopFunName n -> AtomBinding CoreIR n - -deriving instance IRRep r => Show (AtomBinding r n) -deriving via WrapE (AtomBinding r) n instance IRRep r => Generic (AtomBinding r n) - --- name of function, name of arg -type InferenceArgDesc = (String, String) -data InfVarDesc = - ImplicitArgInfVar InferenceArgDesc - | AnnotationInfVar String -- name of binder - | TypeInstantiationInfVar String -- name of type - | MiscInfVar - deriving (Show, Generic, Eq, Ord) - -data SolverBinding (n::S) = - InfVarBound (CType n) InfVarCtx - | SkolemBound (CType n) - deriving (Show, Generic) - --- Context for why we created an inference variable. --- This helps us give better "ambiguous variable" errors. -type InfVarCtx = (SrcPosCtx, InfVarDesc) - -newtype EnvFrag (n::S) (l::S) = EnvFrag (RecSubstFrag Binding n l) - deriving (OutFrag) - -instance HasScope Env where - toScope = toScope . envDefs . topEnv - -instance OutMap Env where - emptyOutMap = - Env (TopEnv (RecSubst emptyInFrag) mempty mempty emptyLoadedModules emptyLoadedObjects) - emptyModuleEnv - {-# INLINE emptyOutMap #-} - -instance ExtOutMap Env (RecSubstFrag Binding) where - -- TODO: We might want to reorganize this struct to make this - -- do less explicit sinking etc. It's a hot operation! - extendOutMap (Env (TopEnv defs rules cache loadedM loadedO) moduleEnv) frag = - withExtEvidence frag $ Env - (TopEnv - (defs `extendRecSubst` frag) - (sink rules) - (sink cache) - (sink loadedM) - (sink loadedO)) - (sink moduleEnv) - {-# INLINE extendOutMap #-} - -instance ExtOutMap Env EnvFrag where - extendOutMap = extendEnv - {-# INLINE extendOutMap #-} - -extendEnv :: Distinct l => Env n -> EnvFrag n l -> Env l -extendEnv env (EnvFrag newEnv) = do - case extendOutMap env newEnv of - Env envTop (ModuleEnv imports sm scs) -> do - Env envTop (ModuleEnv imports sm scs) -{-# NOINLINE [1] extendEnv #-} - -- === effects === data Effect (r::IR) (n::S) = @@ -910,31 +618,6 @@ instance IRRep r => Store (EffectRowTail r n) instance IRRep r => Store (EffectRow r n) instance IRRep r => Store (Effect r n) --- === Specialization and generalization === - -type Generalized (r::IR) (e::E) (n::S) = (Abstracted r e n, [Atom r n]) -type Abstracted (r::IR) (e::E) = Abs (Nest (Binder r)) e -type AbsDict = Abstracted CoreIR Dict - -data SpecializedDictDef n = - SpecializedDict - (AbsDict n) - -- Methods (thunked if nullary), if they're available. - -- We create specialized dict names during simplification, but we don't - -- actually simplify/lower them until we return to TopLevel - (Maybe [TopLam SimpIR n]) - deriving (Show, Generic) - --- TODO: extend with AD-oriented specializations, backend-specific specializations etc. -data SpecializationSpec (n::S) = - AppSpecialization (AtomVar CoreIR n) (Abstracted CoreIR (ListE CAtom) n) - deriving (Show, Generic) - -type Active = Bool -data LinearizationSpec (n::S) = - LinearizationSpec (TopFunName n) [Active] - deriving (Show, Generic) - -- === Binder utils === binderType :: Binder r n l -> Type r n @@ -950,58 +633,95 @@ bindersVars = \case Nest b bs -> withExtEvidence b $ withSubscopeDistinct bs $ sink (binderVar b) : bindersVars bs --- === ToBinding === - -atomBindingToBinding :: AtomBinding r n -> Binding (AtomNameC r) n -atomBindingToBinding b = AtomNameBinding b - -bindingToAtomBinding :: Binding (AtomNameC r) n -> AtomBinding r n -bindingToAtomBinding (AtomNameBinding b) = b - -class (RenameE e, SinkableE e) => ToBinding (e::E) (c::C) | e -> c where - toBinding :: e n -> Binding c n - -instance Color c => ToBinding (Binding c) c where - toBinding = id - -instance IRRep r => ToBinding (AtomBinding r) (AtomNameC r) where - toBinding = atomBindingToBinding - -instance IRRep r => ToBinding (DeclBinding r) (AtomNameC r) where - toBinding = toBinding . LetBound - -instance IRRep r => ToBinding (Type r) (AtomNameC r) where - toBinding = toBinding . MiscBound - -instance ToBinding SolverBinding (AtomNameC CoreIR) where - toBinding = toBinding . SolverBound - -instance IRRep r => ToBinding (IxType r) (AtomNameC r) where - toBinding (IxType t _) = toBinding t - -instance (ToBinding e1 c, ToBinding e2 c) => ToBinding (EitherE e1 e2) c where - toBinding (LeftE e) = toBinding e - toBinding (RightE e) = toBinding e +-- === ToAtom === + +class ToAtom (e::E) (r::IR) | e -> r where + toAtom :: e n -> Atom r n + +instance ToAtom (Atom r) r where toAtom = id +instance ToAtom (Con r) r where toAtom = Con +instance ToAtom (TyCon CoreIR) CoreIR where toAtom = Con . TyConAtom +instance ToAtom (DictCon CoreIR) CoreIR where toAtom = Con . DictConAtom +instance ToAtom (EffectRow CoreIR) CoreIR where toAtom = Con . Eff +instance ToAtom CoreLamExpr CoreIR where toAtom = Con . Lam +instance ToAtom DictType CoreIR where toAtom = Con . TyConAtom . DictTy +instance ToAtom NewtypeTyCon CoreIR where toAtom = Con . TyConAtom . NewtypeTyCon +instance ToAtom (AtomVar r) r where + toAtom (AtomVar v ty) = Stuck ty (Var (AtomVar v ty)) +instance ToAtom RepVal SimpIR where + toAtom (RepVal ty tree) = Stuck ty $ RepValAtom $ RepVal ty tree +instance ToAtom (Type CoreIR) CoreIR where + toAtom = \case + TyCon con -> Con $ TyConAtom con + StuckTy t s -> Stuck t s +instance ToAtom (Dict CoreIR) CoreIR where + toAtom = \case + DictCon d -> Con $ DictConAtom d + StuckDict t s -> Stuck t s + +-- This can help avoid ambiguous `r` parameter with ToAtom +toAtomR :: ToAtom (e r) r => e r n -> Atom r n +toAtomR = toAtom + +-- === ToType === + +class ToType (e::E) (r::IR) | e -> r where + toType :: e n -> Type r n + +instance ToType (Type r) r where toType = id +instance ToType (TyCon r) r where toType = TyCon +instance ToType (TabPiType r) r where toType = TyCon . TabPi +instance ToType (DepPairType r) r where toType = TyCon . DepPairTy +instance ToType CorePiType CoreIR where toType = TyCon . Pi +instance ToType DictType CoreIR where toType = TyCon . DictTy +instance ToType NewtypeTyCon CoreIR where toType = TyCon . NewtypeTyCon + +toMaybeType :: CAtom n -> Maybe (CType n) +toMaybeType = \case + Stuck t s -> Just $ StuckTy t s + Con (TyConAtom t) -> Just $ TyCon t + _ -> Nothing + +-- === ToDict === + +class ToDict (e::E) (r::IR) | e -> r where + toDict :: e n -> Dict r n + +instance ToDict (Dict r) r where toDict = id +instance ToDict (DictCon r) r where toDict = DictCon +instance ToDict CAtomVar CoreIR where + toDict (AtomVar v ty) = StuckDict ty (Var (AtomVar v ty)) + +toMaybeDict :: CAtom n -> Maybe (CDict n) +toMaybeDict = \case + Stuck t s -> Just $ StuckDict t s + Con (DictConAtom d) -> Just $ DictCon d + _ -> Nothing + +-- === ToExpr === + +class ToExpr (e::E) (r::IR) | e -> r where + toExpr :: e n -> Expr r n + +instance ToExpr (Expr r) r where toExpr = id +instance ToExpr (Atom r) r where toExpr = Atom +instance ToExpr (Con r) r where toExpr = Atom . Con +instance ToExpr (AtomVar r) r where toExpr = toExpr . toAtom +instance ToExpr (PrimOp r) r where toExpr = PrimOp +instance ToExpr (MiscOp r) r where toExpr = PrimOp . MiscOp +instance ToExpr (MemOp r) r where toExpr = PrimOp . MemOp +instance ToExpr (VectorOp r) r where toExpr = PrimOp . VectorOp +instance ToExpr (TypedHof r) r where toExpr = PrimOp . Hof +instance ToExpr (DAMOp SimpIR) SimpIR where toExpr = PrimOp . DAMOp -- === Pattern synonyms === --- XXX: only use this pattern when you're actually expecting a type. If it's --- a Var, it doesn't check whether it's a type. -pattern Type :: CType n -> CAtom n -pattern Type t <- ((\case Var v -> Just (TyVar v) - ProjectElt t i x -> Just $ ProjectEltTy t i x - TypeAsAtom t -> Just t - _ -> Nothing) -> Just t) - where Type (TyVar v) = Var v - Type (ProjectEltTy t i x) = ProjectElt t i x - Type t = TypeAsAtom t - pattern IdxRepScalarBaseTy :: ScalarBaseType pattern IdxRepScalarBaseTy = Word32Type -- Type used to represent indices and sizes at run-time pattern IdxRepTy :: Type r n -pattern IdxRepTy = TC (BaseType (Scalar Word32Type)) +pattern IdxRepTy = TyCon (BaseType (Scalar Word32Type)) pattern IdxRepVal :: Word32 -> Atom r n pattern IdxRepVal x = Con (Lit (Word32Lit x)) @@ -1014,7 +734,7 @@ pattern IIdxRepTy = Scalar Word32Type -- Type used to represent sum type tags at run-time pattern TagRepTy :: Type r n -pattern TagRepTy = TC (BaseType (Scalar Word8Type)) +pattern TagRepTy = TyCon (BaseType (Scalar Word8Type)) pattern TagRepVal :: Word8 -> Atom r n pattern TagRepVal x = Con (Lit (Word8Lit x)) @@ -1026,91 +746,70 @@ charRepVal :: Char -> Atom r n charRepVal c = Con (Lit (Word8Lit (fromIntegral $ fromEnum c))) pattern Word8Ty :: Type r n -pattern Word8Ty = TC (BaseType (Scalar Word8Type)) - -pattern ProdTy :: [Type r n] -> Type r n -pattern ProdTy tys = TC (ProdType tys) - -pattern ProdVal :: [Atom r n] -> Atom r n -pattern ProdVal xs = Con (ProdCon xs) - -pattern SumTy :: [Type r n] -> Type r n -pattern SumTy cs = TC (SumType cs) - -pattern SumVal :: [Type r n] -> Int -> Atom r n -> Atom r n -pattern SumVal tys tag payload = Con (SumCon tys tag payload) +pattern Word8Ty = TyCon (BaseType (Scalar Word8Type)) pattern PairVal :: Atom r n -> Atom r n -> Atom r n pattern PairVal x y = Con (ProdCon [x, y]) pattern PairTy :: Type r n -> Type r n -> Type r n -pattern PairTy x y = TC (ProdType [x, y]) +pattern PairTy x y = TyCon (ProdType [x, y]) pattern UnitVal :: Atom r n pattern UnitVal = Con (ProdCon []) pattern UnitTy :: Type r n -pattern UnitTy = TC (ProdType []) +pattern UnitTy = TyCon (ProdType []) pattern BaseTy :: BaseType -> Type r n -pattern BaseTy b = TC (BaseType b) +pattern BaseTy b = TyCon (BaseType b) pattern PtrTy :: PtrType -> Type r n -pattern PtrTy ty = BaseTy (PtrType ty) +pattern PtrTy ty = TyCon (BaseType (PtrType ty)) pattern RefTy :: Atom r n -> Type r n -> Type r n -pattern RefTy r a = TC (RefType r a) +pattern RefTy r a = TyCon (RefType r a) pattern RawRefTy :: Type r n -> Type r n -pattern RawRefTy a = TC (RefType (Con HeapVal) a) +pattern RawRefTy a = TyCon (RefType (Con HeapVal) a) pattern TabTy :: IxDict r n -> Binder r n l -> Type r l -> Type r n -pattern TabTy d b body = TabPi (TabPiType d b body) +pattern TabTy d b body = TyCon (TabPi (TabPiType d b body)) pattern FinTy :: Atom CoreIR n -> Type CoreIR n -pattern FinTy n = NewtypeTyCon (Fin n) +pattern FinTy n = TyCon (NewtypeTyCon (Fin n)) pattern NatTy :: Type CoreIR n -pattern NatTy = NewtypeTyCon Nat +pattern NatTy = TyCon (NewtypeTyCon Nat) pattern NatVal :: Word32 -> Atom CoreIR n -pattern NatVal n = NewtypeCon NatCon (IdxRepVal n) +pattern NatVal n = Con (NewtypeCon NatCon (IdxRepVal n)) -pattern TyKind :: Kind r n -pattern TyKind = TC TypeKind +pattern TyKind :: Kind CoreIR n +pattern TyKind = TyCon TypeKind pattern EffKind :: Kind CoreIR n -pattern EffKind = NewtypeTyCon EffectRowKind +pattern EffKind = TyCon (NewtypeTyCon EffectRowKind) pattern FinConst :: Word32 -> Type CoreIR n -pattern FinConst n = NewtypeTyCon (Fin (NatVal n)) +pattern FinConst n = TyCon (NewtypeTyCon (Fin (NatVal n))) -pattern NullaryLamExpr :: Block r n -> LamExpr r n +pattern NullaryLamExpr :: Expr r n -> LamExpr r n pattern NullaryLamExpr body = LamExpr Empty body -pattern UnaryLamExpr :: Binder r n l -> Block r l -> LamExpr r n +pattern UnaryLamExpr :: Binder r n l -> Expr r l -> LamExpr r n pattern UnaryLamExpr b body = LamExpr (UnaryNest b) body -pattern BinaryLamExpr :: Binder r n l1 -> Binder r l1 l2 -> Block r l2 -> LamExpr r n +pattern BinaryLamExpr :: Binder r n l1 -> Binder r l1 l2 -> Expr r l2 -> LamExpr r n pattern BinaryLamExpr b1 b2 body = LamExpr (BinaryNest b1 b2) body -pattern WithoutDecls :: e n -> WithDecls r e n -pattern WithoutDecls x = Abs Empty x - -exprBlock :: IRRep r => Block r n -> Maybe (Expr r n) -exprBlock (Abs (Nest (Let b (DeclBinding _ expr)) Empty) (Var (AtomVar n _))) - | n == binderName b = Just expr -exprBlock _ = Nothing -{-# INLINE exprBlock #-} - pattern MaybeTy :: Type r n -> Type r n -pattern MaybeTy a = SumTy [UnitTy, a] +pattern MaybeTy a = TyCon (SumType [UnitTy, a]) pattern NothingAtom :: Type r n -> Atom r n -pattern NothingAtom a = SumVal [UnitTy, a] 0 UnitVal +pattern NothingAtom a = Con (SumCon [UnitTy, a] 0 UnitVal) pattern JustAtom :: Type r n -> Atom r n -> Atom r n -pattern JustAtom a x = SumVal [UnitTy, a] 1 x +pattern JustAtom a x = Con (SumCon [UnitTy, a] 1 x) pattern BoolTy :: Type r n pattern BoolTy = Word8Ty @@ -1123,34 +822,16 @@ pattern TrueAtom = Con (Lit (Word8Lit 1)) -- === Typeclass instances for Name and other Haskell libraries === -instance GenericE AtomRules where - type RepE AtomRules = (LiftE (Int, Int, SymbolicZeros)) `PairE` CAtom - fromE (CustomLinearize ni ne sz a) = LiftE (ni, ne, sz) `PairE` a - toE (LiftE (ni, ne, sz) `PairE` a) = CustomLinearize ni ne sz a -instance SinkableE AtomRules -instance HoistableE AtomRules -instance AlphaEqE AtomRules -instance RenameE AtomRules - -instance IRRep r => GenericE (RepVal r) where - type RepE (RepVal r) = PairE (Type r) (ComposeE Tree IExpr) +instance GenericE RepVal where + type RepE RepVal= PairE SType (ComposeE Tree IExpr) fromE (RepVal ty tree) = ty `PairE` ComposeE tree toE (ty `PairE` ComposeE tree) = RepVal ty tree -instance IRRep r => SinkableE (RepVal r) -instance IRRep r => RenameE (RepVal r) -instance IRRep r => HoistableE (RepVal r) -instance IRRep r => AlphaHashableE (RepVal r) -instance IRRep r => AlphaEqE (RepVal r) - -instance GenericE CustomRules where - type RepE CustomRules = ListE (PairE (AtomName CoreIR) AtomRules) - fromE (CustomRules m) = ListE $ toPairE <$> M.toList m - toE (ListE l) = CustomRules $ M.fromList $ fromPairE <$> l -instance SinkableE CustomRules -instance HoistableE CustomRules -instance AlphaEqE CustomRules -instance RenameE CustomRules +instance SinkableE RepVal +instance RenameE RepVal +instance HoistableE RepVal +instance AlphaHashableE RepVal +instance AlphaEqE RepVal instance GenericE TyConParams where type RepE TyConParams = PairE (LiftE [Explicitness]) (ListE CAtom) @@ -1324,7 +1005,6 @@ instance IRRep r => RenameE (DAMOp r) instance IRRep r => AlphaEqE (DAMOp r) instance IRRep r => AlphaHashableE (DAMOp r) -instance IsPrimOp TypedHof where toPrimOp = Hof instance IRRep r => GenericE (TypedHof r) where type RepE (TypedHof r) = EffTy r `PairE` Hof r fromE (TypedHof effTy hof) = effTy `PairE` hof @@ -1342,14 +1022,14 @@ instance IRRep r => GenericE (Hof r) where type RepE (Hof r) = EitherE2 (EitherE6 {- For -} (LiftE ForAnn `PairE` IxType r `PairE` LamExpr r) - {- While -} (Block r) + {- While -} (Expr r) {- RunReader -} (Atom r `PairE` LamExpr r) {- RunWriter -} (MaybeE (Atom r) `PairE` BaseMonoid r `PairE` LamExpr r) {- RunState -} (MaybeE (Atom r) `PairE` Atom r `PairE` LamExpr r) - {- RunIO -} (Block r) + {- RunIO -} (Expr r) ) (EitherE4 - {- RunInit -} (Block r) - {- CatchException -} (WhenCore r (Type r `PairE` Block r)) + {- RunInit -} (Expr r) + {- CatchException -} (WhenCore r (Type r `PairE` Expr r)) {- Linearize -} (WhenCore r (LamExpr r `PairE` Atom r)) {- Transpose -} (WhenCore r (LamExpr r `PairE` Atom r))) @@ -1420,103 +1100,81 @@ instance IRRep r => RenameE (RefOp r) instance IRRep r => AlphaEqE (RefOp r) instance IRRep r => AlphaHashableE (RefOp r) -instance GenericE SimpInCore where - type RepE SimpInCore = EitherE4 - {- LiftSimp -} (CType `PairE` SAtom) - {- LiftSimpFun -} (CorePiType `PairE` LamExpr SimpIR) - {- TabLam -} (TabPiType CoreIR `PairE` TabLamExpr) - {- ACase -} (SAtom `PairE` ListE (Abs SBinder CAtom) `PairE` CType) +instance IRRep r => GenericE (Atom r) where + type RepE (Atom r) = EitherE (PairE (Type r) (Stuck r)) (Con r) fromE = \case - LiftSimp ty x -> Case0 $ ty `PairE` x - LiftSimpFun ty x -> Case1 $ ty `PairE` x - TabLam ty lam -> Case2 $ ty `PairE` lam - ACase scrut alts resultTy -> Case3 $ scrut `PairE` ListE alts `PairE` resultTy + Stuck t x -> LeftE (PairE t x) + Con x -> RightE x {-# INLINE fromE #-} - toE = \case - Case0 (ty `PairE` x) -> LiftSimp ty x - Case1 (ty `PairE` x) -> LiftSimpFun ty x - Case2 (ty `PairE` lam) -> TabLam ty lam - Case3 (x `PairE` ListE alts `PairE` ty) -> ACase x alts ty - _ -> error "impossible" + LeftE (PairE t x) -> Stuck t x + RightE x -> Con x {-# INLINE toE #-} -instance SinkableE SimpInCore -instance HoistableE SimpInCore -instance RenameE SimpInCore -instance AlphaEqE SimpInCore -instance AlphaHashableE SimpInCore +instance IRRep r => SinkableE (Atom r) +instance IRRep r => HoistableE (Atom r) +instance IRRep r => AlphaEqE (Atom r) +instance IRRep r => AlphaHashableE (Atom r) +instance IRRep r => RenameE (Atom r) -instance IRRep r => GenericE (Atom r) where - -- As tempting as it might be to reorder cases here, the current permutation - -- was chosen as to make GHC inliner confident enough to simplify through - -- toE/fromE entirely. If you wish to modify the order, please consult the - -- GHC Core dump to make sure you haven't regressed this optimization. - type RepE (Atom r) = EitherE3 - (EitherE4 - {- Var -} (AtomVar r) - {- ProjectElt -} (Type r `PairE` LiftE Projection `PairE` Atom r) - {- Lam -} (WhenCore r CoreLamExpr) - {- DepPair -} (Atom r `PairE` Atom r `PairE` DepPairType r) - ) (EitherE4 - {- DictCon -} (WhenCore r (CType `PairE` DictExpr)) - {- NewtypeCon -} (WhenCore r (NewtypeCon `PairE` Atom r)) - {- DictHole -} (WhenCore r (LiftE (AlwaysEqual SrcPosCtx) `PairE` - (Type CoreIR) `PairE` - (LiftE RequiredMethodAccess))) - {- Con -} (Con r) - ) (EitherE5 - {- Eff -} ( WhenCore r (EffectRow r)) - {- PtrVar -} (LiftE PtrType `PairE` PtrName) - {- RepValAtom -} ( WhenSimp r (RepVal r)) - {- SimpInCore -} ( WhenCore r SimpInCore) - {- TypeAsAtom -} ( WhenCore r (Type CoreIR)) - ) - - fromE atom = case atom of - Var v -> Case0 (Case0 v) - ProjectElt t idxs x -> Case0 (Case1 (t `PairE` LiftE idxs `PairE` x)) - Lam lamExpr -> Case0 (Case2 (WhenIRE lamExpr)) - DepPair l r ty -> Case0 (Case3 $ l `PairE` r `PairE` ty) - DictCon t d -> Case1 $ Case0 $ WhenIRE $ t `PairE` d - NewtypeCon c x -> Case1 $ Case1 $ WhenIRE (c `PairE` x) - DictHole s t access -> Case1 $ Case2 $ WhenIRE (LiftE s `PairE` t `PairE` LiftE access) - Con con -> Case1 $ Case3 con - Eff effs -> Case2 $ Case0 $ WhenIRE effs - PtrVar t v -> Case2 $ Case1 $ LiftE t `PairE` v - RepValAtom rv -> Case2 $ Case2 $ WhenIRE $ rv - SimpInCore x -> Case2 $ Case3 $ WhenIRE x - TypeAsAtom t -> Case2 $ Case4 $ WhenIRE t +instance IRRep r => GenericE (Stuck r) where + type RepE (Stuck r) = EitherE2 + (EitherE6 + {- Var -} (AtomVar r) + {- StuckProject -} (LiftE Int `PairE` Stuck r) + {- StuckTabApp -} (Stuck r `PairE` Atom r) + {- StuckUnwrap -} (WhenCore r (CStuck)) + {- InstantiatedGiven -} (WhenCore r (CStuck `PairE` ListE CAtom)) + {- SuperclassProj -} (WhenCore r (LiftE Int `PairE` CStuck)) + ) (EitherE6 + {- PtrVar -} (LiftE PtrType `PairE` PtrName) + {- RepValAtom -} (WhenSimp r RepVal) + {- LiftSimp -} (WhenCore r (CType `PairE` SStuck)) + {- LiftSimpFun -} (WhenCore r (CorePiType `PairE` LamExpr SimpIR)) + {- TabLam -} (WhenCore r TabLamExpr) + {- ACase -} (WhenCore r (SStuck `PairE` ListE (Abs SBinder CAtom) `PairE` CType)) + ) + + fromE = \case + Var v -> Case0 $ Case0 v + StuckProject i e -> Case0 $ Case1 $ LiftE i `PairE` e + StuckTabApp f x -> Case0 $ Case2 $ f `PairE` x + StuckUnwrap e -> Case0 $ Case3 $ WhenIRE $ e + InstantiatedGiven e xs -> Case0 $ Case4 $ WhenIRE $ e `PairE` ListE xs + SuperclassProj i e -> Case0 $ Case5 $ WhenIRE $ LiftE i `PairE` e + PtrVar t p -> Case1 $ Case0 $ LiftE t `PairE` p + RepValAtom r -> Case1 $ Case1 $ WhenIRE r + LiftSimp t x -> Case1 $ Case2 $ WhenIRE $ t `PairE` x + LiftSimpFun t lam -> Case1 $ Case3 $ WhenIRE $ t `PairE` lam + TabLam lam -> Case1 $ Case4 $ WhenIRE lam + ACase s alts ty -> Case1 $ Case5 $ WhenIRE $ s `PairE` ListE alts `PairE` ty {-# INLINE fromE #-} - toE atom = case atom of - Case0 val -> case val of - Case0 v -> Var v - Case1 (t `PairE` LiftE idxs `PairE` x) -> ProjectElt t idxs x - Case2 (WhenIRE (lamExpr)) -> Lam lamExpr - Case3 (l `PairE` r `PairE` ty) -> DepPair l r ty - _ -> error "impossible" - Case1 val -> case val of - Case0 (WhenIRE (t `PairE` d)) -> DictCon t d - Case1 (WhenIRE (c `PairE` x)) -> NewtypeCon c x - Case2 (WhenIRE (LiftE s `PairE` t `PairE` LiftE access)) -> DictHole s t access - Case3 con -> Con con + toE = \case + Case0 con -> case con of + Case0 v -> Var v + Case1 (LiftE i `PairE` e) -> StuckProject i e + Case2 (f `PairE` x) -> StuckTabApp f x + Case3 (WhenIRE e) -> StuckUnwrap e + Case4 (WhenIRE (e `PairE` ListE xs)) -> InstantiatedGiven e xs + Case5 (WhenIRE (LiftE i `PairE` e)) -> SuperclassProj i e _ -> error "impossible" - Case2 val -> case val of - Case0 (WhenIRE effs) -> Eff effs - Case1 (LiftE t `PairE` v) -> PtrVar t v - Case2 (WhenIRE rv) -> RepValAtom rv - Case3 (WhenIRE x) -> SimpInCore x - Case4 (WhenIRE t) -> TypeAsAtom t + Case1 con -> case con of + Case0 (LiftE t `PairE` p) -> PtrVar t p + Case1 (WhenIRE r) -> RepValAtom r + Case2 (WhenIRE (t `PairE` x)) -> LiftSimp t x + Case3 (WhenIRE (t `PairE` lam)) -> LiftSimpFun t lam + Case4 (WhenIRE lam) -> TabLam lam + Case5 (WhenIRE (s `PairE` ListE alts `PairE` ty)) -> ACase s alts ty _ -> error "impossible" _ -> error "impossible" {-# INLINE toE #-} -instance IRRep r => SinkableE (Atom r) -instance IRRep r => HoistableE (Atom r) -instance IRRep r => AlphaEqE (Atom r) -instance IRRep r => AlphaHashableE (Atom r) -instance IRRep r => RenameE (Atom r) +instance IRRep r => SinkableE (Stuck r) +instance IRRep r => HoistableE (Stuck r) +instance IRRep r => AlphaEqE (Stuck r) +instance IRRep r => AlphaHashableE (Stuck r) +instance IRRep r => RenameE (Stuck r) instance IRRep r => GenericE (AtomVar r) where type RepE (AtomVar r) = PairE (AtomName r) (Type r) @@ -1534,7 +1192,6 @@ instance Eq (AtomVar r n) where instance IRRep r => SinkableE (AtomVar r) instance IRRep r => HoistableE (AtomVar r) - -- We ignore the type annotation because it should be determined by the var instance IRRep r => AlphaEqE (AtomVar r) where alphaEqE (AtomVar v _) (AtomVar v' _) = alphaEqE v v' @@ -1546,36 +1203,14 @@ instance IRRep r => AlphaHashableE (AtomVar r) where instance IRRep r => RenameE (AtomVar r) instance IRRep r => GenericE (Type r) where - type RepE (Type r) = EitherE8 - {- TyVar -} (WhenCore r CAtomVar) - {- Pi -} (WhenCore r CorePiType) - {- TabPi -} (TabPiType r) - {- DepPairTy -} (DepPairType r) - {- DictTy -} (WhenCore r DictType) - {- NewtypeTyCon -} (WhenCore r NewtypeTyCon) - {- TC -} (TC r) - {- ProjectEltTy -} (WhenCore r (Type r `PairE` LiftE Projection `PairE` Atom r)) - + type RepE (Type r) = EitherE (WhenCore r (PairE (Type r) (Stuck r))) (TyCon r) fromE = \case - TyVar v -> Case0 $ WhenIRE v - Pi t -> Case1 $ WhenIRE t - TabPi t -> Case2 t - DepPairTy t -> Case3 t - DictTy d -> Case4 $ WhenIRE d - NewtypeTyCon t -> Case5 $ WhenIRE t - TC con -> Case6 $ con - ProjectEltTy t idxs x -> Case7 (WhenIRE (t `PairE` LiftE idxs `PairE` x)) + StuckTy t x -> LeftE (WhenIRE (PairE t x)) + TyCon x -> RightE x {-# INLINE fromE #-} - toE = \case - Case0 (WhenIRE v) -> TyVar v - Case1 (WhenIRE t) -> Pi t - Case2 t -> TabPi t - Case3 t -> DepPairTy t - Case4 (WhenIRE d) -> DictTy d - Case5 (WhenIRE t) -> NewtypeTyCon t - Case6 con -> TC con - Case7 (WhenIRE (t `PairE` LiftE idxs `PairE` x)) -> ProjectEltTy t idxs x + LeftE (WhenIRE (PairE t x)) -> StuckTy t x + RightE x -> TyCon x {-# INLINE toE #-} instance IRRep r => SinkableE (Type r) @@ -1586,40 +1221,48 @@ instance IRRep r => RenameE (Type r) instance IRRep r => GenericE (Expr r) where type RepE (Expr r) = EitherE2 - ( EitherE5 + ( EitherE6 {- App -} (WhenCore r (EffTy r `PairE` Atom r `PairE` ListE (Atom r))) - {- TabApp -} (Type r `PairE` Atom r `PairE` ListE (Atom r)) + {- TabApp -} (Type r `PairE` Atom r `PairE` Atom r) {- Case -} (Atom r `PairE` ListE (Alt r) `PairE` EffTy r) {- Atom -} (Atom r) {- TopApp -} (WhenSimp r (EffTy r `PairE` TopFunName `PairE` ListE (Atom r))) + {- Block -} (EffTy r `PairE` Block r) ) - ( EitherE3 - {- TabCon -} (MaybeE (WhenCore r Dict) `PairE` Type r `PairE` ListE (Atom r)) + ( EitherE5 + {- TabCon -} (MaybeE (WhenCore r (Dict CoreIR)) `PairE` Type r `PairE` ListE (Atom r)) {- PrimOp -} (PrimOp r) - {- ApplyMethod -} (WhenCore r (EffTy r `PairE` Atom r `PairE` LiftE Int `PairE` ListE (Atom r)))) - + {- ApplyMethod -} (WhenCore r (EffTy r `PairE` Atom r `PairE` LiftE Int `PairE` ListE (Atom r))) + {- Project -} (Type r `PairE` LiftE Int `PairE` Atom r) + {- Unwrap -} (WhenCore r (CType `PairE` CAtom))) fromE = \case App et f xs -> Case0 $ Case0 (WhenIRE (et `PairE` f `PairE` ListE xs)) - TabApp t f xs -> Case0 $ Case1 (t `PairE` f `PairE` ListE xs) + TabApp t f x -> Case0 $ Case1 (t `PairE` f `PairE` x) Case e alts effTy -> Case0 $ Case2 (e `PairE` ListE alts `PairE` effTy) Atom x -> Case0 $ Case3 (x) - TopApp et f xs -> Case0 $ Case4 (WhenIRE (et `PairE` f `PairE` ListE xs)) - TabCon d ty xs -> Case1 $ Case0 (toMaybeE d `PairE` ty `PairE` ListE xs) - PrimOp op -> Case1 $ Case1 op + TopApp et f xs -> Case0 $ Case4 (WhenIRE (et `PairE` f `PairE` ListE xs)) + Block et block -> Case0 $ Case5 (et `PairE` block) + TabCon d ty xs -> Case1 $ Case0 (toMaybeE d `PairE` ty `PairE` ListE xs) + PrimOp op -> Case1 $ Case1 op ApplyMethod et d i xs -> Case1 $ Case2 (WhenIRE (et `PairE` d `PairE` LiftE i `PairE` ListE xs)) + Project ty i x -> Case1 $ Case3 (ty `PairE` LiftE i `PairE` x) + Unwrap t x -> Case1 $ Case4 (WhenIRE (t `PairE` x)) {-# INLINE fromE #-} toE = \case Case0 case0 -> case case0 of Case0 (WhenIRE (et `PairE` f `PairE` ListE xs)) -> App et f xs - Case1 (t `PairE` f `PairE` ListE xs) -> TabApp t f xs - Case2 (e `PairE` ListE alts `PairE` effTy) -> Case e alts effTy + Case1 (t `PairE` f `PairE` x) -> TabApp t f x + Case2 (e `PairE` ListE alts `PairE` effTy) -> Case e alts effTy Case3 (x) -> Atom x Case4 (WhenIRE (et `PairE` f `PairE` ListE xs)) -> TopApp et f xs + Case5 (et `PairE` block) -> Block et block _ -> error "impossible" Case1 case1 -> case case1 of Case0 (d `PairE` ty `PairE` ListE xs) -> TabCon (fromMaybeE d) ty xs Case1 op -> PrimOp op Case2 (WhenIRE (et `PairE` d `PairE` LiftE i `PairE` ListE xs)) -> ApplyMethod et d i xs + Case3 (ty `PairE` LiftE i `PairE` x) -> Project ty i x + Case4 (WhenIRE (t `PairE` x)) -> Unwrap t x _ -> error "impossible" _ -> error "impossible" {-# INLINE toE #-} @@ -1693,7 +1336,6 @@ instance GenericOp VectorOp where _ -> Nothing {-# INLINE toOp #-} -instance IsPrimOp VectorOp where toPrimOp = VectorOp instance IRRep r => GenericE (VectorOp r) where type RepE (VectorOp r) = GenericOpRep (OpConst VectorOp r) r fromE = fromEGenericOpRep @@ -1722,7 +1364,6 @@ instance GenericOp MemOp where _ -> Nothing {-# INLINE toOp #-} -instance IsPrimOp MemOp where toPrimOp = MemOp instance IRRep r => GenericE (MemOp r) where type RepE (MemOp r) = GenericOpRep (OpConst MemOp r) r fromE = fromEGenericOpRep @@ -1765,7 +1406,6 @@ instance GenericOp MiscOp where _ -> Nothing {-# INLINE toOp #-} -instance IsPrimOp MiscOp where toPrimOp = MiscOp instance IRRep r => GenericE (MiscOp r) where type RepE (MiscOp r) = GenericOpRep (OpConst MiscOp r) r fromE = fromEGenericOpRep @@ -1776,27 +1416,49 @@ instance IRRep r => AlphaEqE (MiscOp r) instance IRRep r => AlphaHashableE (MiscOp r) instance IRRep r => RenameE (MiscOp r) -instance GenericOp Con where - type OpConst Con r = Either LitVal P.Con - fromOp = \case - Lit l -> GenericOpRep (Left l) [] [] [] - ProdCon xs -> GenericOpRep (Right P.ProdCon) [] xs [] - SumCon tys i x -> GenericOpRep (Right (P.SumCon i)) tys [x] [] - HeapVal -> GenericOpRep (Right P.HeapVal) [] [] [] - {-# INLINE fromOp #-} - - toOp = \case - GenericOpRep (Left l) [] [] [] -> Just $ Lit l - GenericOpRep (Right P.ProdCon) [] xs [] -> Just $ ProdCon xs - GenericOpRep (Right (P.SumCon i)) tys [x] [] -> Just $ SumCon tys i x - GenericOpRep (Right P.HeapVal) [] [] [] -> Just $ HeapVal - _ -> Nothing - {-# INLINE toOp #-} - instance IRRep r => GenericE (Con r) where - type RepE (Con r) = GenericOpRep (OpConst Con r) r - fromE = fromEGenericOpRep - toE = toEGenericOpRep + type RepE (Con r) = EitherE2 + (EitherE5 + {- Lit -} (LiftE LitVal) + {- ProdCon -} (ListE (Atom r)) + {- SumCon -} (ListE (Type r) `PairE` LiftE Int `PairE` Atom r) + {- HeapVal -} UnitE + {- DepPair -} (Atom r `PairE` Atom r `PairE` DepPairType r)) + (WhenCore r (EitherE5 + {- Lam -} CoreLamExpr + {- Eff -} (EffectRow CoreIR) + {- NewtypeCon -} (NewtypeCon `PairE` CAtom) + {- DictConAtom -} (DictCon CoreIR) + {- TyConAtom -} (TyCon CoreIR))) + fromE = \case + Lit l -> Case0 $ Case0 $ LiftE l + ProdCon xs -> Case0 $ Case1 $ ListE xs + SumCon ts i x -> Case0 $ Case2 $ ListE ts `PairE` LiftE i `PairE` x + HeapVal -> Case0 $ Case3 $ UnitE + DepPair x y t -> Case0 $ Case4 $ x `PairE` y `PairE` t + Lam lam -> Case1 $ WhenIRE $ Case0 lam + Eff effs -> Case1 $ WhenIRE $ Case1 effs + NewtypeCon con x -> Case1 $ WhenIRE $ Case2 $ con `PairE` x + DictConAtom con -> Case1 $ WhenIRE $ Case3 con + TyConAtom tc -> Case1 $ WhenIRE $ Case4 tc + {-# INLINE fromE #-} + toE = \case + Case0 con -> case con of + Case0 (LiftE l) -> Lit l + Case1 (ListE xs) -> ProdCon xs + Case2 (ListE ts `PairE` LiftE i `PairE` x) -> SumCon ts i x + Case3 UnitE -> HeapVal + Case4 (x `PairE` y `PairE` t) -> DepPair x y t + _ -> error "impossible" + Case1 (WhenIRE con) -> case con of + Case0 lam -> Lam lam + Case1 effs -> Eff effs + Case2 (con' `PairE` x) -> NewtypeCon con' x + Case3 con' -> DictConAtom con' + Case4 tc -> TyConAtom tc + _ -> error "impossible" + _ -> error "impossible" + {-# INLINE toE #-} instance IRRep r => SinkableE (Con r) instance IRRep r => HoistableE (Con r) @@ -1804,36 +1466,61 @@ instance IRRep r => AlphaEqE (Con r) instance IRRep r => AlphaHashableE (Con r) instance IRRep r => RenameE (Con r) -instance GenericOp TC where - type OpConst TC r = Either BaseType P.TC - fromOp = \case - BaseType b -> GenericOpRep (Left b) [] [] [] - ProdType ts -> GenericOpRep (Right P.ProdType) ts [] [] - SumType ts -> GenericOpRep (Right P.SumType) ts [] [] - RefType h t -> GenericOpRep (Right P.RefType) [t] [h] [] - TypeKind -> GenericOpRep (Right P.TypeKind) [] [] [] - HeapType -> GenericOpRep (Right P.HeapType) [] [] [] - {-# INLINE fromOp #-} - - toOp = \case - GenericOpRep (Left b) [] [] [] -> Just (BaseType b) - GenericOpRep (Right P.ProdType) ts [] [] -> Just (ProdType ts) - GenericOpRep (Right P.SumType) ts [] [] -> Just (SumType ts) - GenericOpRep (Right P.RefType) [t] [h] [] -> Just (RefType h t) - GenericOpRep (Right P.TypeKind) [] [] [] -> Just TypeKind - GenericOpRep (Right P.HeapType) [] [] [] -> Just HeapType - GenericOpRep _ _ _ _ -> Nothing - {-# INLINE toOp #-} +instance IRRep r => GenericE (TyCon r) where + type RepE (TyCon r) = EitherE3 + (EitherE4 + {- BaseType -} (LiftE BaseType) + {- ProdType -} (ListE (Type r)) + {- SumType -} (ListE (Type r)) + {- RefType -} (Atom r `PairE` Type r)) + (EitherE4 + {- HeapType -} UnitE + {- TabPi -} (TabPiType r) + {- DepPairTy -} (DepPairType r) + {- TypeKind -} (WhenCore r UnitE)) + (EitherE3 + {- DictTy -} (WhenCore r DictType) + {- Pi -} (WhenCore r CorePiType) + {- NewtypeTyCon -} (WhenCore r NewtypeTyCon)) + fromE = \case + BaseType b -> Case0 (Case0 (LiftE b)) + ProdType ts -> Case0 (Case1 (ListE ts)) + SumType ts -> Case0 (Case2 (ListE ts)) + RefType h t -> Case0 (Case3 (h `PairE` t)) + HeapType -> Case1 (Case0 UnitE) + TabPi t -> Case1 (Case1 t) + DepPairTy t -> Case1 (Case2 t) + TypeKind -> Case1 (Case3 (WhenIRE UnitE)) + DictTy t -> Case2 (Case0 (WhenIRE t)) + Pi t -> Case2 (Case1 (WhenIRE t)) + NewtypeTyCon t -> Case2 (Case2 (WhenIRE t)) + {-# INLINE fromE #-} + toE = \case + Case0 c -> case c of + Case0 (LiftE b ) -> BaseType b + Case1 (ListE ts) -> ProdType ts + Case2 (ListE ts) -> SumType ts + Case3 (h `PairE` t) -> RefType h t + _ -> error "impossible" + Case1 c -> case c of + Case0 UnitE -> HeapType + Case1 t -> TabPi t + Case2 t -> DepPairTy t + Case3 (WhenIRE UnitE) -> TypeKind + _ -> error "impossible" + Case2 c -> case c of + Case0 (WhenIRE t) -> DictTy t + Case1 (WhenIRE t) -> Pi t + Case2 (WhenIRE t) -> NewtypeTyCon t + _ -> error "impossible" + _ -> error "impossible" + {-# INLINE toE #-} -instance IRRep r => GenericE (TC r) where - type RepE (TC r) = GenericOpRep (OpConst TC r) r - fromE = fromEGenericOpRep - toE = toEGenericOpRep -instance IRRep r => SinkableE (TC r) -instance IRRep r => HoistableE (TC r) -instance IRRep r => AlphaEqE (TC r) -instance IRRep r => AlphaHashableE (TC r) -instance IRRep r => RenameE (TC r) +instance IRRep r => SinkableE (TyCon r) +instance IRRep r => HoistableE (TyCon r) +instance IRRep r => AlphaEqE (TyCon r) +instance IRRep r => AlphaHashableE (TyCon r) +instance IRRep r => RenameE (TyCon r) instance IRRep r => GenericB (NonDepNest r ann) where type RepB (NonDepNest r ann) = (LiftB (ListE ann)) `PairB` Nest (AtomNameBinder r) @@ -1852,13 +1539,13 @@ deriving instance (Show (ann n)) => IRRep r => Show (NonDepNest r ann n l) instance GenericE ClassDef where type RepE ClassDef = - LiftE (SourceName, [SourceName], [Maybe SourceName], [RoleExpl]) + LiftE (SourceName, Maybe BuiltinClassName, [SourceName], [Maybe SourceName], [RoleExpl]) `PairE` Abs (Nest CBinder) (Abs (Nest CBinder) (ListE CorePiType)) - fromE (ClassDef name names paramNames roleExpls b scs tys) = - LiftE (name, names, paramNames, roleExpls) `PairE` Abs b (Abs scs (ListE tys)) + fromE (ClassDef name builtin names paramNames roleExpls b scs tys) = + LiftE (name, builtin, names, paramNames, roleExpls) `PairE` Abs b (Abs scs (ListE tys)) {-# INLINE fromE #-} - toE (LiftE (name, names, paramNames, roleExpls) `PairE` Abs b (Abs scs (ListE tys))) = - ClassDef name names paramNames roleExpls b scs tys + toE (LiftE (name, builtin, names, paramNames, roleExpls) `PairE` Abs b (Abs scs (ListE tys))) = + ClassDef name builtin names paramNames roleExpls b scs tys {-# INLINE toE #-} instance SinkableE ClassDef @@ -1869,7 +1556,7 @@ instance RenameE ClassDef deriving instance Show (ClassDef n) deriving via WrapE ClassDef n instance Generic (ClassDef n) instance HasSourceName (ClassDef n) where - getSourceName = \case ClassDef name _ _ _ _ _ _ -> name + getSourceName = \case ClassDef name _ _ _ _ _ _ _ -> name instance GenericE InstanceDef where type RepE InstanceDef = @@ -1899,11 +1586,19 @@ instance AlphaHashableE InstanceBody instance RenameE InstanceBody instance GenericE DictType where - type RepE DictType = LiftE SourceName `PairE` ClassName `PairE` ListE CAtom - fromE (DictType sourceName className params) = - LiftE sourceName `PairE` className `PairE` ListE params - toE (LiftE sourceName `PairE` className `PairE` ListE params) = - DictType sourceName className params + type RepE DictType = EitherE3 + {- DictType -} (LiftE SourceName `PairE` ClassName `PairE` ListE CAtom) + {- IxDictType -} CType + {- DataDictType -} CType + fromE = \case + DictType sourceName className params -> Case0 $ LiftE sourceName `PairE` className `PairE` ListE params + IxDictType ty -> Case1 ty + DataDictType ty -> Case2 ty + toE = \case + Case0 (LiftE sourceName `PairE` className `PairE` ListE params) -> DictType sourceName className params + Case1 ty -> IxDictType ty + Case2 ty -> DataDictType ty + _ -> error "impossible" instance SinkableE DictType instance HoistableE DictType @@ -1911,75 +1606,52 @@ instance AlphaEqE DictType instance AlphaHashableE DictType instance RenameE DictType -instance GenericE DictExpr where - type RepE DictExpr = - EitherE5 - {- InstanceDict -} (PairE InstanceName (ListE CAtom)) - {- InstantiatedGiven -} (PairE CAtom (ListE CAtom)) - {- SuperclassProj -} (PairE CAtom (LiftE Int)) - {- IxFin -} CAtom - {- DataData -} CType - fromE d = case d of - InstanceDict v args -> Case0 $ PairE v (ListE args) - InstantiatedGiven given args -> Case1 $ PairE given (ListE args) - SuperclassProj x i -> Case2 (PairE x (LiftE i)) - IxFin x -> Case3 x - DataData ty -> Case4 ty - toE d = case d of - Case0 (PairE v (ListE args)) -> InstanceDict v args - Case1 (PairE given (ListE args)) -> InstantiatedGiven given args - Case2 (PairE x (LiftE i)) -> SuperclassProj x i - Case3 x -> IxFin x - Case4 ty -> DataData ty - _ -> error "impossible" - -instance SinkableE DictExpr -instance HoistableE DictExpr -instance AlphaEqE DictExpr -instance AlphaHashableE DictExpr -instance RenameE DictExpr - -instance GenericE Cache where - type RepE Cache = - EMap SpecializationSpec TopFunName - `PairE` EMap AbsDict SpecDictName - `PairE` EMap LinearizationSpec (PairE TopFunName TopFunName) - `PairE` EMap TopFunName TopFunName - `PairE` LiftE (M.Map ModuleSourceName (FileHash, [ModuleSourceName])) - `PairE` ListE ( LiftE ModuleSourceName - `PairE` LiftE FileHash - `PairE` ListE ModuleName - `PairE` ModuleName) - fromE (Cache x y z w parseCache evalCache) = - x `PairE` y `PairE` z `PairE` w `PairE` LiftE parseCache `PairE` - ListE [LiftE sourceName `PairE` LiftE hashVal `PairE` ListE deps `PairE` result - | (sourceName, ((hashVal, deps), result)) <- M.toList evalCache ] +instance IRRep r => GenericE (Dict r) where + type RepE (Dict r) = EitherE (WhenCore r (PairE (Type r) (Stuck r))) (DictCon r) + fromE = \case + StuckDict t d -> LeftE (WhenIRE (PairE t d)) + DictCon d -> RightE d {-# INLINE fromE #-} - toE (x `PairE` y `PairE` z `PairE` w `PairE` LiftE parseCache `PairE` ListE evalCache) = - Cache x y z w parseCache - (M.fromList - [(sourceName, ((hashVal, deps), result)) - | LiftE sourceName `PairE` LiftE hashVal `PairE` ListE deps `PairE` result - <- evalCache]) + toE = \case + LeftE (WhenIRE (PairE t d)) -> StuckDict t d + RightE d -> DictCon d {-# INLINE toE #-} -instance SinkableE Cache -instance HoistableE Cache -instance AlphaEqE Cache -instance RenameE Cache -instance Store (Cache n) - -instance Monoid (Cache n) where - mempty = Cache mempty mempty mempty mempty mempty mempty - mappend = (<>) +instance IRRep r => SinkableE (Dict r) +instance IRRep r => HoistableE (Dict r) +instance IRRep r => AlphaEqE (Dict r) +instance IRRep r => AlphaHashableE (Dict r) +instance IRRep r => RenameE (Dict r) + +instance IRRep r => GenericE (DictCon r) where + type RepE (DictCon r) = EitherE5 + {- InstanceDict -} (WhenCore r (CType `PairE` PairE InstanceName (ListE CAtom))) + {- IxFin -} (WhenCore r CAtom) + {- DataData -} (WhenCore r CType) + {- IxRawFin -} (Atom r) + {- IxSpecialized -} (WhenSimp r (SpecDictName `PairE` ListE SAtom)) + fromE = \case + InstanceDict t v args -> Case0 $ WhenIRE $ t `PairE` PairE v (ListE args) + IxFin x -> Case1 $ WhenIRE $ x + DataData ty -> Case2 $ WhenIRE $ ty + IxRawFin n -> Case3 $ n + IxSpecialized d xs -> Case4 $ WhenIRE $ d `PairE` ListE xs + toE = \case + Case0 (WhenIRE (t `PairE` (PairE v (ListE args)))) -> InstanceDict t v args + Case1 (WhenIRE x) -> IxFin x + Case2 (WhenIRE ty) -> DataData ty + Case3 n -> IxRawFin n + Case4 (WhenIRE (d `PairE` ListE xs)) -> IxSpecialized d xs + _ -> error "impossible" -instance Semigroup (Cache n) where - -- right-biased instead of left-biased - Cache x1 x2 x3 x4 x5 x6 <> Cache y1 y2 y3 y4 y5 y6 = - Cache (y1<>x1) (y2<>x2) (y3<>x3) (y4<>x4) (x5<>y5) (x6<>y6) +instance IRRep r => SinkableE (DictCon r) +instance IRRep r => HoistableE (DictCon r) +instance IRRep r => AlphaEqE (DictCon r) +instance IRRep r => AlphaHashableE (DictCon r) +instance IRRep r => RenameE (DictCon r) instance GenericE (LamExpr r) where - type RepE (LamExpr r) = Abs (Nest (Binder r)) (Block r) + type RepE (LamExpr r) = Abs (Nest (Binder r)) (Expr r) fromE (LamExpr b block) = Abs b block {-# INLINE fromE #-} toE (Abs b block) = LamExpr b block @@ -2005,8 +1677,6 @@ instance HoistableE CoreLamExpr instance AlphaEqE CoreLamExpr instance AlphaHashableE CoreLamExpr instance RenameE CoreLamExpr -deriving instance Show (CoreLamExpr n) -deriving via WrapE CoreLamExpr n instance Generic (CoreLamExpr n) instance GenericE CorePiType where type RepE CorePiType = LiftE (AppExplicitness, [Explicitness]) `PairE` Abs (Nest CBinder) (EffTy CoreIR) @@ -2023,30 +1693,6 @@ instance RenameE CorePiType deriving instance Show (CorePiType n) deriving via WrapE CorePiType n instance Generic (CorePiType n) -instance IRRep r => GenericE (IxDict r) where - type RepE (IxDict r) = - EitherE3 - (WhenCore r (Atom r)) - (Atom r) - (WhenSimp r (Type r `PairE` SpecDictName `PairE` ListE (Atom r))) - fromE = \case - IxDictAtom x -> Case0 $ WhenIRE x - IxDictRawFin n -> Case1 $ n - IxDictSpecialized t d xs -> Case2 $ WhenIRE $ t `PairE` d `PairE` ListE xs - {-# INLINE fromE #-} - toE = \case - Case0 (WhenIRE x) -> IxDictAtom x - Case1 (n) -> IxDictRawFin n - Case2 (WhenIRE (t `PairE` d `PairE` ListE xs)) -> IxDictSpecialized t d xs - _ -> error "impossible" - {-# INLINE toE #-} - -instance IRRep r => SinkableE (IxDict r) -instance IRRep r => HoistableE (IxDict r) -instance IRRep r => RenameE (IxDict r) -instance IRRep r => AlphaEqE (IxDict r) -instance IRRep r => AlphaHashableE (IxDict r) - instance IRRep r => GenericE (IxType r) where type RepE (IxType r) = PairE (Type r) (IxDict r) fromE (IxType ty d) = PairE ty d @@ -2118,225 +1764,6 @@ instance IRRep r => RenameE (DepPairType r) deriving instance IRRep r => Show (DepPairType r n) deriving via WrapE (DepPairType r) n instance IRRep r => Generic (DepPairType r n) -instance GenericE SynthCandidates where - type RepE SynthCandidates = - ListE (AtomName CoreIR) `PairE` ListE (PairE ClassName (ListE InstanceName)) - fromE (SynthCandidates xs ys) = ListE xs `PairE` ListE ys' - where ys' = map (\(k,vs) -> PairE k (ListE vs)) (M.toList ys) - {-# INLINE fromE #-} - toE (ListE xs `PairE` ListE ys) = SynthCandidates xs ys' - where ys' = M.fromList $ map (\(PairE k (ListE vs)) -> (k,vs)) ys - {-# INLINE toE #-} - -instance SinkableE SynthCandidates -instance HoistableE SynthCandidates -instance AlphaEqE SynthCandidates -instance AlphaHashableE SynthCandidates -instance RenameE SynthCandidates - -instance IRRep r => GenericE (AtomBinding r) where - type RepE (AtomBinding r) = - EitherE2 (EitherE3 - (DeclBinding r) -- LetBound - (Type r) -- MiscBound - (WhenCore r SolverBinding) -- SolverBound - ) (EitherE3 - (WhenCore r (PairE CType CAtom)) -- NoinlineFun - (WhenSimp r (RepVal SimpIR)) -- TopDataBound - (WhenCore r (CorePiType `PairE` TopFunName)) -- FFIFunBound - ) - - fromE = \case - LetBound x -> Case0 $ Case0 x - MiscBound x -> Case0 $ Case1 x - SolverBound x -> Case0 $ Case2 $ WhenIRE x - NoinlineFun t x -> Case1 $ Case0 $ WhenIRE $ PairE t x - TopDataBound repVal -> Case1 $ Case1 $ WhenIRE repVal - FFIFunBound ty v -> Case1 $ Case2 $ WhenIRE $ ty `PairE` v - {-# INLINE fromE #-} - - toE = \case - Case0 x' -> case x' of - Case0 x -> LetBound x - Case1 x -> MiscBound x - Case2 (WhenIRE x) -> SolverBound x - _ -> error "impossible" - Case1 x' -> case x' of - Case0 (WhenIRE (PairE t x)) -> NoinlineFun t x - Case1 (WhenIRE repVal) -> TopDataBound repVal - Case2 (WhenIRE (ty `PairE` v)) -> FFIFunBound ty v - _ -> error "impossible" - _ -> error "impossible" - {-# INLINE toE #-} - - -instance IRRep r => SinkableE (AtomBinding r) -instance IRRep r => HoistableE (AtomBinding r) -instance IRRep r => RenameE (AtomBinding r) -instance IRRep r => AlphaEqE (AtomBinding r) -instance IRRep r => AlphaHashableE (AtomBinding r) - -instance GenericE TopFunDef where - type RepE TopFunDef = EitherE3 SpecializationSpec LinearizationSpec LinearizationSpec - fromE = \case - Specialization s -> Case0 s - LinearizationPrimal s -> Case1 s - LinearizationTangent s -> Case2 s - {-# INLINE fromE #-} - toE = \case - Case0 s -> Specialization s - Case1 s -> LinearizationPrimal s - Case2 s -> LinearizationTangent s - _ -> error "impossible" - {-# INLINE toE #-} - -instance SinkableE TopFunDef -instance HoistableE TopFunDef -instance RenameE TopFunDef -instance AlphaEqE TopFunDef -instance AlphaHashableE TopFunDef - -instance IRRep r => GenericE (TopLam r) where - type RepE (TopLam r) = LiftE Bool `PairE` PiType r `PairE` LamExpr r - fromE (TopLam d x y) = LiftE d `PairE` x `PairE` y - {-# INLINE fromE #-} - toE (LiftE d `PairE` x `PairE` y) = TopLam d x y - {-# INLINE toE #-} - -instance IRRep r => SinkableE (TopLam r) -instance IRRep r => HoistableE (TopLam r) -instance IRRep r => RenameE (TopLam r) -instance IRRep r => AlphaEqE (TopLam r) -instance IRRep r => AlphaHashableE (TopLam r) - -instance GenericE TopFun where - type RepE TopFun = EitherE - (TopFunDef `PairE` TopLam SimpIR `PairE` ComposeE EvalStatus TopFunLowerings) - (LiftE (String, IFunType)) - fromE = \case - DexTopFun def lam status -> LeftE (def `PairE` lam `PairE` ComposeE status) - FFITopFun name ty -> RightE (LiftE (name, ty)) - {-# INLINE fromE #-} - toE = \case - LeftE (def `PairE` lam `PairE` ComposeE status) -> DexTopFun def lam status - RightE (LiftE (name, ty)) -> FFITopFun name ty - {-# INLINE toE #-} - -instance SinkableE TopFun -instance HoistableE TopFun -instance RenameE TopFun -instance AlphaEqE TopFun -instance AlphaHashableE TopFun - -instance GenericE SpecializationSpec where - type RepE SpecializationSpec = - PairE (AtomVar CoreIR) (Abs (Nest (Binder CoreIR)) (ListE CAtom)) - fromE (AppSpecialization fname (Abs bs args)) = PairE fname (Abs bs args) - {-# INLINE fromE #-} - toE (PairE fname (Abs bs args)) = AppSpecialization fname (Abs bs args) - {-# INLINE toE #-} - -instance HasNameHint (SpecializationSpec n) where - getNameHint (AppSpecialization f _) = getNameHint f - -instance SinkableE SpecializationSpec -instance HoistableE SpecializationSpec -instance RenameE SpecializationSpec -instance AlphaEqE SpecializationSpec -instance AlphaHashableE SpecializationSpec - -instance GenericE LinearizationSpec where - type RepE LinearizationSpec = PairE TopFunName (LiftE [Active]) - fromE (LinearizationSpec fname actives) = PairE fname (LiftE actives) - {-# INLINE fromE #-} - toE (PairE fname (LiftE actives)) = LinearizationSpec fname actives - {-# INLINE toE #-} - -instance SinkableE LinearizationSpec -instance HoistableE LinearizationSpec -instance RenameE LinearizationSpec -instance AlphaEqE LinearizationSpec -instance AlphaHashableE LinearizationSpec - -instance GenericE SolverBinding where - type RepE SolverBinding = EitherE2 - (PairE CType (LiftE InfVarCtx)) - CType - fromE = \case - InfVarBound ty ctx -> Case0 (PairE ty (LiftE ctx)) - SkolemBound ty -> Case1 ty - {-# INLINE fromE #-} - - toE = \case - Case0 (PairE ty (LiftE ct)) -> InfVarBound ty ct - Case1 ty -> SkolemBound ty - _ -> error "impossible" - {-# INLINE toE #-} - -instance SinkableE SolverBinding -instance HoistableE SolverBinding -instance RenameE SolverBinding -instance AlphaEqE SolverBinding -instance AlphaHashableE SolverBinding - -instance GenericE (Binding c) where - type RepE (Binding c) = - EitherE3 - (EitherE6 - (WhenAtomName c AtomBinding) - (WhenC TyConNameC c (MaybeE TyConDef `PairE` DotMethods)) - (WhenC DataConNameC c (TyConName `PairE` LiftE Int)) - (WhenC ClassNameC c (ClassDef)) - (WhenC InstanceNameC c (InstanceDef `PairE` CorePiType)) - (WhenC MethodNameC c (ClassName `PairE` LiftE Int))) - (EitherE4 - (WhenC TopFunNameC c (TopFun)) - (WhenC FunObjCodeNameC c (CFunction)) - (WhenC ModuleNameC c (Module)) - (WhenC PtrNameC c (LiftE (PtrType, PtrLitVal)))) - (EitherE2 - (WhenC SpecializedDictNameC c (SpecializedDictDef)) - (WhenC ImpNameC c (LiftE BaseType))) - - fromE = \case - AtomNameBinding binding -> Case0 $ Case0 $ WhenAtomName binding - TyConBinding dataDef methods -> Case0 $ Case1 $ WhenC $ toMaybeE dataDef `PairE` methods - DataConBinding dataDefName idx -> Case0 $ Case2 $ WhenC $ dataDefName `PairE` LiftE idx - ClassBinding classDef -> Case0 $ Case3 $ WhenC $ classDef - InstanceBinding instanceDef ty -> Case0 $ Case4 $ WhenC $ instanceDef `PairE` ty - MethodBinding className idx -> Case0 $ Case5 $ WhenC $ className `PairE` LiftE idx - TopFunBinding fun -> Case1 $ Case0 $ WhenC $ fun - FunObjCodeBinding cFun -> Case1 $ Case1 $ WhenC $ cFun - ModuleBinding m -> Case1 $ Case2 $ WhenC $ m - PtrBinding ty p -> Case1 $ Case3 $ WhenC $ LiftE (ty,p) - SpecializedDictBinding def -> Case2 $ Case0 $ WhenC $ def - ImpNameBinding ty -> Case2 $ Case1 $ WhenC $ LiftE ty - {-# INLINE fromE #-} - - toE = \case - Case0 (Case0 (WhenAtomName binding)) -> AtomNameBinding binding - Case0 (Case1 (WhenC (def `PairE` methods))) -> TyConBinding (fromMaybeE def) methods - Case0 (Case2 (WhenC (n `PairE` LiftE idx))) -> DataConBinding n idx - Case0 (Case3 (WhenC (classDef))) -> ClassBinding classDef - Case0 (Case4 (WhenC (instanceDef `PairE` ty))) -> InstanceBinding instanceDef ty - Case0 (Case5 (WhenC ((n `PairE` LiftE i)))) -> MethodBinding n i - Case1 (Case0 (WhenC (fun))) -> TopFunBinding fun - Case1 (Case1 (WhenC (f))) -> FunObjCodeBinding f - Case1 (Case2 (WhenC (m))) -> ModuleBinding m - Case1 (Case3 (WhenC ((LiftE (ty,p))))) -> PtrBinding ty p - Case2 (Case0 (WhenC (def))) -> SpecializedDictBinding def - Case2 (Case1 (WhenC ((LiftE ty)))) -> ImpNameBinding ty - _ -> error "impossible" - {-# INLINE toE #-} - -deriving via WrapE (Binding c) n instance Generic (Binding c n) -instance SinkableV Binding -instance HoistableV Binding -instance RenameV Binding -instance Color c => SinkableE (Binding c) -instance Color c => HoistableE (Binding c) -instance Color c => RenameE (Binding c) - instance GenericE DotMethods where type RepE DotMethods = ListE (LiftE SourceName `PairE` CAtomName) fromE (DotMethods xys) = ListE $ [LiftE x `PairE` y | (x, y) <- M.toList xys] @@ -2454,338 +1881,417 @@ instance IRRep r => BindsOneName (Decl r) (AtomNameC r) where binderName (Let b _) = binderName b {-# INLINE binderName #-} -instance Semigroup (SynthCandidates n) where - SynthCandidates xs ys <> SynthCandidates xs' ys' = - SynthCandidates (xs<>xs') (M.unionWith (<>) ys ys') - -instance Monoid (SynthCandidates n) where - mempty = SynthCandidates mempty mempty - -instance GenericB EnvFrag where - type RepB EnvFrag = RecSubstFrag Binding - fromB (EnvFrag frag) = frag - toB frag = EnvFrag frag - -instance SinkableB EnvFrag -instance HoistableB EnvFrag -instance ProvesExt EnvFrag -instance BindsNames EnvFrag -instance RenameB EnvFrag - -instance GenericE TopEnvUpdate where - type RepE TopEnvUpdate = EitherE2 ( - EitherE4 - {- ExtendCache -} Cache - {- AddCustomRule -} (CAtomName `PairE` AtomRules) - {- UpdateLoadedModules -} (LiftE ModuleSourceName `PairE` ModuleName) - {- UpdateLoadedObjects -} (FunObjCodeName `PairE` LiftE NativeFunction) - ) ( EitherE6 - {- FinishDictSpecialization -} (SpecDictName `PairE` ListE (TopLam SimpIR)) - {- LowerDictSpecialization -} (SpecDictName `PairE` ListE (TopLam SimpIR)) - {- UpdateTopFunEvalStatus -} (TopFunName `PairE` ComposeE EvalStatus TopFunLowerings) - {- UpdateInstanceDef -} (InstanceName `PairE` InstanceDef) - {- UpdateTyConDef -} (TyConName `PairE` TyConDef) - {- UpdateFieldDef -} (TyConName `PairE` LiftE SourceName `PairE` CAtomName) - ) - fromE = \case - ExtendCache x -> Case0 $ Case0 x - AddCustomRule x y -> Case0 $ Case1 (x `PairE` y) - UpdateLoadedModules x y -> Case0 $ Case2 (LiftE x `PairE` y) - UpdateLoadedObjects x y -> Case0 $ Case3 (x `PairE` LiftE y) - FinishDictSpecialization x y -> Case1 $ Case0 (x `PairE` ListE y) - LowerDictSpecialization x y -> Case1 $ Case1 (x `PairE` ListE y) - UpdateTopFunEvalStatus x y -> Case1 $ Case2 (x `PairE` ComposeE y) - UpdateInstanceDef x y -> Case1 $ Case3 (x `PairE` y) - UpdateTyConDef x y -> Case1 $ Case4 (x `PairE` y) - UpdateFieldDef x y z -> Case1 $ Case5 (x `PairE` LiftE y `PairE` z) - - toE = \case - Case0 e -> case e of - Case0 x -> ExtendCache x - Case1 (x `PairE` y) -> AddCustomRule x y - Case2 (LiftE x `PairE` y) -> UpdateLoadedModules x y - Case3 (x `PairE` LiftE y) -> UpdateLoadedObjects x y - _ -> error "impossible" - Case1 e -> case e of - Case0 (x `PairE` ListE y) -> FinishDictSpecialization x y - Case1 (x `PairE` ListE y) -> LowerDictSpecialization x y - Case2 (x `PairE` ComposeE y) -> UpdateTopFunEvalStatus x y - Case3 (x `PairE` y) -> UpdateInstanceDef x y - Case4 (x `PairE` y) -> UpdateTyConDef x y - Case5 (x `PairE` LiftE y `PairE` z) -> UpdateFieldDef x y z - _ -> error "impossible" - _ -> error "impossible" - -instance SinkableE TopEnvUpdate -instance HoistableE TopEnvUpdate -instance RenameE TopEnvUpdate - -instance GenericB TopEnvFrag where - type RepB TopEnvFrag = PairB EnvFrag (LiftB (ModuleEnv `PairE` ListE TopEnvUpdate)) - fromB (TopEnvFrag x y (ReversedList z)) = PairB x (LiftB (y `PairE` ListE z)) - toB (PairB x (LiftB (y `PairE` ListE z))) = TopEnvFrag x y (ReversedList z) - -instance RenameB TopEnvFrag -instance HoistableB TopEnvFrag -instance SinkableB TopEnvFrag -instance ProvesExt TopEnvFrag -instance BindsNames TopEnvFrag - -instance OutFrag TopEnvFrag where - emptyOutFrag = TopEnvFrag emptyOutFrag mempty mempty - {-# INLINE emptyOutFrag #-} - catOutFrags (TopEnvFrag frag1 env1 partial1) - (TopEnvFrag frag2 env2 partial2) = - withExtEvidence frag2 $ - TopEnvFrag - (catOutFrags frag1 frag2) - (sink env1 <> env2) - (sinkSnocList partial1 <> partial2) - {-# INLINE catOutFrags #-} - --- XXX: unlike `ExtOutMap Env EnvFrag` instance, this once doesn't --- extend the synthesis candidates based on the annotated let-bound names. It --- only extends synth candidates when they're supplied explicitly. -instance ExtOutMap Env TopEnvFrag where - extendOutMap env (TopEnvFrag (EnvFrag frag) mEnv' otherUpdates) = do - let newerTopEnv = foldl applyUpdate newTopEnv otherUpdates - Env newerTopEnv newModuleEnv - where - Env (TopEnv defs rules cache loadedM loadedO) mEnv = env - - newTopEnv = withExtEvidence frag $ TopEnv - (defs `extendRecSubst` frag) - (sink rules) (sink cache) (sink loadedM) (sink loadedO) - - newModuleEnv = - ModuleEnv - (imports <> imports') - (sm <> sm' <> newImportedSM) - (scs <> scs' <> newImportedSC) - where - ModuleEnv imports sm scs = withExtEvidence frag $ sink mEnv - ModuleEnv imports' sm' scs' = mEnv' - newDirectImports = S.difference (directImports imports') (directImports imports) - newTransImports = S.difference (transImports imports') (transImports imports) - newImportedSM = flip foldMap newDirectImports $ moduleExports . lookupModulePure - newImportedSC = flip foldMap newTransImports $ moduleSynthCandidates . lookupModulePure - - lookupModulePure v = case lookupEnvPure newTopEnv v of ModuleBinding m -> m - -applyUpdate :: TopEnv n -> TopEnvUpdate n -> TopEnv n -applyUpdate e = \case - ExtendCache cache -> e { envCache = envCache e <> cache} - AddCustomRule x y -> e { envCustomRules = envCustomRules e <> CustomRules (M.singleton x y)} - UpdateLoadedModules x y -> e { envLoadedModules = envLoadedModules e <> LoadedModules (M.singleton x y)} - UpdateLoadedObjects x y -> e { envLoadedObjects = envLoadedObjects e <> LoadedObjects (M.singleton x y)} - FinishDictSpecialization dName methods -> do - let SpecializedDictBinding (SpecializedDict dAbs oldMethods) = lookupEnvPure e dName - case oldMethods of - Nothing -> do - let newBinding = SpecializedDictBinding $ SpecializedDict dAbs (Just methods) - updateEnv dName newBinding e - Just _ -> error "shouldn't be adding methods if we already have them" - LowerDictSpecialization dName methods -> do - let SpecializedDictBinding (SpecializedDict dAbs _) = lookupEnvPure e dName - let newBinding = SpecializedDictBinding $ SpecializedDict dAbs (Just methods) - updateEnv dName newBinding e - UpdateTopFunEvalStatus f s -> do - case lookupEnvPure e f of - TopFunBinding (DexTopFun def lam _) -> - updateEnv f (TopFunBinding $ DexTopFun def lam s) e - _ -> error "can't update ffi function impl" - UpdateInstanceDef name def -> do - case lookupEnvPure e name of - InstanceBinding _ ty -> updateEnv name (InstanceBinding def ty) e - UpdateTyConDef name def -> do - let TyConBinding _ methods = lookupEnvPure e name - updateEnv name (TyConBinding (Just def) methods) e - UpdateFieldDef name sn x -> do - let TyConBinding def methods = lookupEnvPure e name - updateEnv name (TyConBinding def (methods <> DotMethods (M.singleton sn x))) e - -updateEnv :: Color c => Name c n -> Binding c n -> TopEnv n -> TopEnv n -updateEnv v rhs env = - env { envDefs = RecSubst $ updateSubstFrag v rhs bs } - where (RecSubst bs) = envDefs env - -lookupEnvPure :: Color c => TopEnv n -> Name c n -> Binding c n -lookupEnvPure env v = lookupTerminalSubstFrag (fromRecSubst $ envDefs $ env) v - -instance GenericE Module where - type RepE Module = LiftE ModuleSourceName - `PairE` ListE ModuleName - `PairE` ListE ModuleName - `PairE` SourceMap - `PairE` SynthCandidates - - fromE (Module name deps transDeps sm sc) = - LiftE name `PairE` ListE (S.toList deps) `PairE` ListE (S.toList transDeps) - `PairE` sm `PairE` sc - {-# INLINE fromE #-} - - toE (LiftE name `PairE` ListE deps `PairE` ListE transDeps - `PairE` sm `PairE` sc) = - Module name (S.fromList deps) (S.fromList transDeps) sm sc - {-# INLINE toE #-} - -instance SinkableE Module -instance HoistableE Module -instance AlphaEqE Module -instance AlphaHashableE Module -instance RenameE Module - -instance GenericE ImportStatus where - type RepE ImportStatus = ListE ModuleName `PairE` ListE ModuleName - fromE (ImportStatus direct trans) = ListE (S.toList direct) - `PairE` ListE (S.toList trans) - {-# INLINE fromE #-} - toE (ListE direct `PairE` ListE trans) = - ImportStatus (S.fromList direct) (S.fromList trans) - {-# INLINE toE #-} - -instance SinkableE ImportStatus -instance HoistableE ImportStatus -instance AlphaEqE ImportStatus -instance AlphaHashableE ImportStatus -instance RenameE ImportStatus - -instance Semigroup (ImportStatus n) where - ImportStatus direct trans <> ImportStatus direct' trans' = - ImportStatus (direct <> direct') (trans <> trans') - -instance Monoid (ImportStatus n) where - mappend = (<>) - mempty = ImportStatus mempty mempty - -instance GenericE LoadedModules where - type RepE LoadedModules = ListE (PairE (LiftE ModuleSourceName) ModuleName) - fromE (LoadedModules m) = - ListE $ M.toList m <&> \(v,md) -> PairE (LiftE v) md - {-# INLINE fromE #-} - toE (ListE pairs) = - LoadedModules $ M.fromList $ pairs <&> \(PairE (LiftE v) md) -> (v, md) - {-# INLINE toE #-} - -instance SinkableE LoadedModules -instance HoistableE LoadedModules -instance AlphaEqE LoadedModules -instance AlphaHashableE LoadedModules -instance RenameE LoadedModules - -instance GenericE LoadedObjects where - type RepE LoadedObjects = ListE (PairE FunObjCodeName (LiftE NativeFunction)) - fromE (LoadedObjects m) = - ListE $ M.toList m <&> \(v,p) -> PairE v (LiftE p) - {-# INLINE fromE #-} - toE (ListE pairs) = - LoadedObjects $ M.fromList $ pairs <&> \(PairE v (LiftE p)) -> (v, p) - {-# INLINE toE #-} - -instance SinkableE LoadedObjects -instance HoistableE LoadedObjects -instance RenameE LoadedObjects - -instance GenericE ModuleEnv where - type RepE ModuleEnv = ImportStatus - `PairE` SourceMap - `PairE` SynthCandidates - fromE (ModuleEnv imports sm sc) = imports `PairE` sm `PairE` sc - {-# INLINE fromE #-} - toE (imports `PairE` sm `PairE` sc) = ModuleEnv imports sm sc - {-# INLINE toE #-} - -instance SinkableE ModuleEnv -instance HoistableE ModuleEnv -instance AlphaEqE ModuleEnv -instance AlphaHashableE ModuleEnv -instance RenameE ModuleEnv - -instance Semigroup (ModuleEnv n) where - ModuleEnv x1 x2 x3 <> ModuleEnv y1 y2 y3 = - ModuleEnv (x1<>y1) (x2<>y2) (x3<>y3) - -instance Monoid (ModuleEnv n) where - mempty = ModuleEnv mempty mempty mempty - -instance Semigroup (LoadedModules n) where - LoadedModules m1 <> LoadedModules m2 = LoadedModules (m2 <> m1) - -instance Monoid (LoadedModules n) where - mempty = LoadedModules mempty - -instance Semigroup (LoadedObjects n) where - LoadedObjects m1 <> LoadedObjects m2 = LoadedObjects (m2 <> m1) - -instance Monoid (LoadedObjects n) where - mempty = LoadedObjects mempty - -instance Hashable InfVarDesc instance Hashable IxMethod instance Hashable ParamRole -instance Hashable a => Hashable (EvalStatus a) +instance Hashable BuiltinClassName instance IRRep r => Store (MiscOp r n) instance IRRep r => Store (VectorOp r n) instance IRRep r => Store (MemOp r n) -instance IRRep r => Store (TC r n) +instance IRRep r => Store (TyCon r n) instance IRRep r => Store (Con r n) instance IRRep r => Store (PrimOp r n) -instance IRRep r => Store (RepVal r n) +instance Store (RepVal n) instance IRRep r => Store (Type r n) instance IRRep r => Store (EffTy r n) +instance IRRep r => Store (Stuck r n) instance IRRep r => Store (Atom r n) instance IRRep r => Store (AtomVar r n) instance IRRep r => Store (Expr r n) -instance Store (SimpInCore n) -instance Store (SolverBinding n) -instance IRRep r => Store (AtomBinding r n) -instance Store (SpecializationSpec n) -instance Store (LinearizationSpec n) instance IRRep r => Store (DeclBinding r n) instance IRRep r => Store (Decl r n l) instance Store (TyConParams n) instance Store (DataConDefs n) instance Store (TyConDef n) instance Store (DataConDef n) -instance IRRep r => Store (TopLam r n) instance IRRep r => Store (LamExpr r n) instance IRRep r => Store (IxType r n) instance Store (CorePiType n) instance Store (CoreLamExpr n) instance IRRep r => Store (TabPiType r n) instance IRRep r => Store (DepPairType r n) -instance Store (AtomRules n) +instance Store BuiltinClassName instance Store (ClassDef n) instance Store (InstanceDef n) instance Store (InstanceBody n) instance Store (DictType n) -instance Store (DictExpr n) +instance IRRep r => Store (DictCon r n) instance Store (EffectDef n) instance Store (EffectOpDef n) instance Store (EffectOpType n) instance Store (EffectOpIdx) -instance Store (SynthCandidates n) -instance Store (Module n) -instance Store (ImportStatus n) -instance Store (TopFunLowerings n) -instance Store a => Store (EvalStatus a) -instance Store (TopFun n) -instance Store (TopFunDef n) -instance Color c => Store (Binding c n) -instance Store (ModuleEnv n) -instance Store (SerializedEnv n) instance Store (ann n) => Store (NonDepNest r ann n l) -instance Store InfVarDesc instance Store IxMethod instance Store ParamRole -instance Store (SpecializedDictDef n) +instance IRRep r => Store (Dict r n) instance IRRep r => Store (TypedHof r n) instance IRRep r => Store (Hof r n) instance IRRep r => Store (RefOp r n) instance IRRep r => Store (BaseMonoid r n) instance IRRep r => Store (DAMOp r n) -instance IRRep r => Store (IxDict r n) instance Store (NewtypeCon n) instance Store (NewtypeTyCon n) instance Store (DotMethods n) + +-- === Pretty instances === + +instance IRRep r => Pretty (Hof r n) where pretty = prettyFromPrettyPrec +instance IRRep r => PrettyPrec (Hof r n) where + prettyPrec hof = atPrec LowestPrec case hof of + For _ _ lam -> "for" <+> pLowest lam + While body -> "while" <+> pArg body + RunReader x body -> "runReader" <+> pArg x <> nest 2 (line <> p body) + RunWriter _ bm body -> "runWriter" <+> pArg bm <> nest 2 (line <> p body) + RunState _ x body -> "runState" <+> pArg x <> nest 2 (line <> p body) + RunIO body -> "runIO" <+> pArg body + RunInit body -> "runInit" <+> pArg body + CatchException _ body -> "catchException" <+> pArg body + Linearize body x -> "linearize" <+> pArg body <+> pArg x + Transpose body x -> "transpose" <+> pArg body <+> pArg x + where + p :: Pretty a => a -> Doc ann + p = pretty + +instance IRRep r => Pretty (DAMOp r n) where pretty = prettyFromPrettyPrec +instance IRRep r => PrettyPrec (DAMOp r n) where + prettyPrec op = atPrec LowestPrec case op of + Seq _ ann _ c lamExpr -> case lamExpr of + UnaryLamExpr b body -> do + "seq" <+> pApp ann <+> pApp c <+> prettyLam (pretty b <> ".") body + _ -> pretty (show op) -- shouldn't happen, but crashing pretty printers make debugging hard + RememberDest _ x y -> "rememberDest" <+> pArg x <+> pArg y + Place r v -> pApp r <+> "r:=" <+> pApp v + Freeze r -> "freeze" <+> pApp r + AllocDest ty -> "alloc" <+> pApp ty + +instance IRRep r => Pretty (TyCon r n) where pretty = prettyFromPrettyPrec +instance IRRep r => PrettyPrec (TyCon r n) where + prettyPrec con = case con of + BaseType b -> prettyPrec b + ProdType [] -> atPrec ArgPrec $ "()" + ProdType as -> atPrec ArgPrec $ align $ group $ + encloseSep "(" ")" ", " $ fmap pApp as + SumType cs -> atPrec ArgPrec $ align $ group $ + encloseSep "(|" "|)" " | " $ fmap pApp cs + RefType h a -> atPrec AppPrec $ pAppArg "Ref" [h] <+> p a + TypeKind -> atPrec ArgPrec "Type" + HeapType -> atPrec ArgPrec "Heap" + Pi piType -> atPrec LowestPrec $ align $ p piType + TabPi piType -> atPrec LowestPrec $ align $ p piType + DepPairTy ty -> prettyPrec ty + DictTy t -> atPrec LowestPrec $ p t + NewtypeTyCon con' -> prettyPrec con' + where + p :: Pretty a => a -> Doc ann + p = pretty + +prettyPrecNewtype :: NewtypeCon n -> CAtom n -> DocPrec ann +prettyPrecNewtype con x = case (con, x) of + (NatCon, (IdxRepVal n)) -> atPrec ArgPrec $ pretty n + (_, x') -> prettyPrec x' + +instance Pretty (NewtypeTyCon n) where pretty = prettyFromPrettyPrec +instance PrettyPrec (NewtypeTyCon n) where + prettyPrec = \case + Nat -> atPrec ArgPrec $ "Nat" + Fin n -> atPrec AppPrec $ "Fin" <+> pArg n + EffectRowKind -> atPrec ArgPrec "EffKind" + UserADTType name _ (TyConParams infs params) -> case (infs, params) of + ([], []) -> atPrec ArgPrec $ pretty name + ([Explicit, Explicit], [l, r]) + | Just sym <- fromInfix (fromString $ pprint name) -> + atPrec ArgPrec $ align $ group $ + parens $ flatAlt " " "" <> pApp l <> line <> pretty sym <+> pApp r + _ -> atPrec LowestPrec $ pAppArg (pretty name) $ ignoreSynthParams (TyConParams infs params) + where + fromInfix :: Text -> Maybe Text + fromInfix t = do + ('(', t') <- uncons t + (t'', ')') <- unsnoc t' + return t'' + +instance IRRep r => Pretty (Con r n) where pretty = prettyFromPrettyPrec +instance IRRep r => PrettyPrec (Con r n) where + prettyPrec = \case + Lit l -> prettyPrec l + ProdCon [x] -> atPrec ArgPrec $ "(" <> pLowest x <> ",)" + ProdCon xs -> atPrec ArgPrec $ align $ group $ + encloseSep "(" ")" ", " $ fmap pLowest xs + SumCon _ tag payload -> atPrec ArgPrec $ + "(" <> p tag <> "|" <+> pApp payload <+> "|)" + HeapVal -> atPrec ArgPrec "HeapValue" + Lam lam -> atPrec LowestPrec $ p lam + DepPair x y _ -> atPrec ArgPrec $ align $ group $ + parens $ p x <+> ",>" <+> p y + Eff e -> atPrec ArgPrec $ p e + DictConAtom d -> atPrec LowestPrec $ p d + NewtypeCon con x -> prettyPrecNewtype con x + TyConAtom ty -> prettyPrec ty + where + p :: Pretty a => a -> Doc ann + p = pretty + +instance IRRep r => Pretty (PrimOp r n) where pretty = prettyFromPrettyPrec +instance IRRep r => PrettyPrec (PrimOp r n) where + prettyPrec = \case + MemOp op -> prettyPrec op + VectorOp op -> prettyPrec op + DAMOp op -> prettyPrec op + Hof (TypedHof _ hof) -> prettyPrec hof + RefOp ref eff -> atPrec LowestPrec case eff of + MAsk -> "ask" <+> pApp ref + MExtend _ x -> "extend" <+> pApp ref <+> pApp x + MGet -> "get" <+> pApp ref + MPut x -> pApp ref <+> ":=" <+> pApp x + IndexRef _ i -> pApp ref <+> "!" <+> pApp i + ProjRef _ i -> "proj_ref" <+> pApp ref <+> p i + UnOp op x -> prettyOpDefault (UUnOp op) [x] + BinOp op x y -> prettyOpDefault (UBinOp op) [x, y] + MiscOp op -> prettyOpGeneric op + where + p :: Pretty a => a -> Doc ann + p = pretty + +instance IRRep r => Pretty (MemOp r n) where pretty = prettyFromPrettyPrec +instance IRRep r => PrettyPrec (MemOp r n) where + prettyPrec = \case + PtrOffset ptr idx -> atPrec LowestPrec $ pApp ptr <+> "+>" <+> pApp idx + PtrLoad ptr -> atPrec AppPrec $ pAppArg "load" [ptr] + op -> prettyOpGeneric op + +instance IRRep r => Pretty (VectorOp r n) where pretty = prettyFromPrettyPrec +instance IRRep r => PrettyPrec (VectorOp r n) where + prettyPrec = \case + VectorBroadcast v vty -> atPrec LowestPrec $ "vbroadcast" <+> pApp v <+> pApp vty + VectorIota vty -> atPrec LowestPrec $ "viota" <+> pApp vty + VectorIdx tbl i vty -> atPrec LowestPrec $ "vslice" <+> pApp tbl <+> pApp i <+> pApp vty + VectorSubref ref i _ -> atPrec LowestPrec $ "vrefslice" <+> pApp ref <+> pApp i + +prettyOpGeneric :: (IRRep r, GenericOp op, Show (OpConst op r)) => op r n -> DocPrec ann +prettyOpGeneric op = case fromEGenericOpRep op of + GenericOpRep op' [] [] [] -> atPrec ArgPrec (pretty $ show op') + GenericOpRep op' ts xs lams -> atPrec AppPrec $ pAppArg (pretty (show op')) xs <+> pretty ts <+> pretty lams + +instance Pretty IxMethod where + pretty method = pretty $ show method + +instance Pretty (TyConParams n) where + pretty (TyConParams _ _) = undefined + +instance Pretty (TyConDef n) where + pretty (TyConDef name _ bs cons) = "data" <+> pretty name <+> pretty bs <> pretty cons + +instance Pretty (DataConDefs n) where + pretty = undefined + +instance Pretty (DataConDef n) where + pretty (DataConDef name _ repTy _) = pretty name <+> ":" <+> pretty repTy + +instance Pretty (ClassDef n) where + pretty (ClassDef classSourceName _ methodNames _ _ params superclasses methodTys) = + "Class:" <+> pretty classSourceName <+> pretty methodNames + <> indented ( + line <> "parameter binders:" <+> pretty params <> + line <> "superclasses:" <+> pretty superclasses <> + line <> "methods:" <+> pretty methodTys) + +instance Pretty ParamRole where + pretty r = pretty (show r) + +instance Pretty (InstanceDef n) where + pretty (InstanceDef className _ bs params _) = + "Instance" <+> pretty className <+> pretty bs <+> pretty params + +instance IRRep r => Pretty (Expr r n) where pretty = prettyFromPrettyPrec +instance IRRep r => PrettyPrec (Expr r n) where + prettyPrec = \case + Atom x -> prettyPrec x + Block _ (Abs decls body) -> atPrec AppPrec $ prettyBlock decls body + App _ f xs -> atPrec AppPrec $ pApp f <+> spaced (toList xs) + TopApp _ f xs -> atPrec AppPrec $ pApp f <+> spaced (toList xs) + TabApp _ f x -> atPrec AppPrec $ pApp f <> brackets (p x) + Case e alts (EffTy effs _) -> prettyPrecCase "case" e alts effs + TabCon _ _ es -> atPrec ArgPrec $ list $ pApp <$> es + PrimOp op -> prettyPrec op + ApplyMethod _ d i xs -> atPrec AppPrec $ "applyMethod" <+> p d <+> p i <+> p xs + Project _ i x -> atPrec AppPrec $ "Project" <+> p i <+> p x + Unwrap _ x -> atPrec AppPrec $ "Unwrap" <+> p x + where + p :: Pretty a => a -> Doc ann + p = pretty + +prettyPrecCase :: IRRep r => Doc ann -> Atom r n -> [Alt r n] -> EffectRow r n -> DocPrec ann +prettyPrecCase name e alts effs = atPrec LowestPrec $ + name <+> pApp e <+> "of" <> + nest 2 (foldMap (\alt -> hardline <> prettyAlt alt) alts + <> effectLine effs) + where + effectLine :: IRRep r => EffectRow r n -> Doc ann + effectLine Pure = "" + effectLine row = hardline <> "case annotated with effects" <+> pretty row + +prettyAlt :: IRRep r => Alt r n -> Doc ann +prettyAlt (Abs b body) = prettyBinderNoAnn b <+> "->" <> nest 2 (pretty body) + +prettyBinderNoAnn :: Binder r n l -> Doc ann +prettyBinderNoAnn (b:>_) = pretty b + +instance IRRep r => Pretty (DeclBinding r n) where + pretty (DeclBinding ann expr) = "Decl" <> pretty ann <+> pretty expr + +instance IRRep r => Pretty (Decl r n l) where + pretty (Let b (DeclBinding ann rhs)) = + align $ annDoc <> pretty b <+> "=" <> (nest 2 $ group $ line <> pLowest rhs) + where annDoc = case ann of NoInlineLet -> pretty ann <> " "; _ -> pretty ann + +instance IRRep r => Pretty (PiType r n) where + pretty (PiType bs (EffTy effs resultTy)) = + (spaced $ unsafeFromNest $ bs) <+> "->" <+> "{" <> pretty effs <> "}" <+> pretty resultTy + +instance IRRep r => Pretty (LamExpr r n) where pretty = prettyFromPrettyPrec +instance IRRep r => PrettyPrec (LamExpr r n) where + prettyPrec (LamExpr bs body) = atPrec LowestPrec $ prettyLam (pretty bs <> ".") body + +instance IRRep r => Pretty (IxType r n) where + pretty (IxType ty dict) = parens $ "IxType" <+> pretty ty <> prettyIxDict dict + +instance IRRep r => Pretty (Dict r n) where + pretty = \case + DictCon con -> pretty con + StuckDict _ stuck -> pretty stuck + +instance IRRep r => Pretty (DictCon r n) where + pretty = \case + InstanceDict _ name args -> "Instance" <+> pretty name <+> pretty args + IxFin n -> "Ix (Fin" <+> pretty n <> ")" + DataData a -> "Data " <+> pretty a + IxRawFin n -> "Ix (RawFin " <> pretty n <> ")" + IxSpecialized d xs -> pretty d <+> pretty xs + +instance Pretty (DictType n) where + pretty = \case + DictType classSourceName _ params -> pretty classSourceName <+> spaced params + IxDictType ty -> "Ix" <+> pretty ty + DataDictType ty -> "Data" <+> pretty ty + +instance IRRep r => Pretty (DepPairType r n) where pretty = prettyFromPrettyPrec +instance IRRep r => PrettyPrec (DepPairType r n) where + prettyPrec (DepPairType _ b rhs) = + atPrec ArgPrec $ align $ group $ parensSep (spaceIfColinear <> "&> ") [pretty b, pretty rhs] + +instance Pretty (CoreLamExpr n) where + pretty (CoreLamExpr _ lam) = pretty lam + +instance IRRep r => Pretty (Atom r n) where pretty = prettyFromPrettyPrec +instance IRRep r => PrettyPrec (Atom r n) where + prettyPrec atom = case atom of + Con e -> prettyPrec e + Stuck _ e -> prettyPrec e + +instance IRRep r => Pretty (Type r n) where pretty = prettyFromPrettyPrec +instance IRRep r => PrettyPrec (Type r n) where + prettyPrec = \case + TyCon e -> prettyPrec e + StuckTy _ e -> prettyPrec e + +instance IRRep r => Pretty (Stuck r n) where pretty = prettyFromPrettyPrec +instance IRRep r => PrettyPrec (Stuck r n) where + prettyPrec = \case + Var v -> atPrec ArgPrec $ p v + StuckProject i v -> atPrec LowestPrec $ "StuckProject" <+> p i <+> p v + StuckTabApp f xs -> atPrec AppPrec $ pArg f <> "." <> pArg xs + StuckUnwrap v -> atPrec LowestPrec $ "StuckUnwrap" <+> p v + InstantiatedGiven v args -> atPrec LowestPrec $ "Given" <+> p v <+> p (toList args) + SuperclassProj d' i -> atPrec LowestPrec $ "SuperclassProj" <+> p d' <+> p i + PtrVar _ v -> atPrec ArgPrec $ p v + RepValAtom x -> atPrec LowestPrec $ pretty x + ACase e alts _ -> atPrec AppPrec $ "acase" <+> p e <+> p alts + LiftSimp ty x -> atPrec ArgPrec $ " p x <+> " : " <+> p ty <+> ">" + LiftSimpFun ty x -> atPrec ArgPrec $ " p x <+> " : " <+> p ty <+> ">" + TabLam lam -> atPrec AppPrec $ "tablam" <+> p lam + where + p :: Pretty a => a -> Doc ann + p = pretty + +instance PrettyPrec (AtomVar r n) where + prettyPrec (AtomVar v _) = prettyPrec v +instance Pretty (AtomVar r n) where pretty = prettyFromPrettyPrec + +instance IRRep r => Pretty (EffectRow r n) where + pretty (EffectRow effs t) = braces $ hsep (punctuate "," (map pretty (eSetToList effs))) <> pretty t + +instance IRRep r => Pretty (EffectRowTail r n) where + pretty = \case + NoTail -> mempty + EffectRowTail v -> "|" <> pretty v + +instance IRRep r => Pretty (Effect r n) where + pretty eff = case eff of + RWSEffect rws h -> pretty rws <+> pretty h + ExceptionEffect -> "Except" + IOEffect -> "IO" + InitEffect -> "Init" + +prettyLam :: Pretty a => Doc ann -> a -> Doc ann +prettyLam binders body = group $ group (nest 4 $ binders) <> group (nest 2 $ pretty body) + +instance IRRep r => Pretty (TabPiType r n) where + pretty (TabPiType dict (b:>ty) body) = let + prettyBody = case body of + TyCon (Pi subpi) -> pretty subpi + _ -> pLowest body + prettyBinder = prettyBinderHelper (b:>ty) body + in prettyBinder <> prettyIxDict dict <> (group $ line <> "=>" <+> prettyBody) + +-- A helper to let us turn dict printing on and off. We mostly want it off to +-- reduce clutter in prints and error messages, but when debugging synthesis we +-- want it on. +prettyIxDict :: IRRep r => IxDict r n -> Doc ann +prettyIxDict dict = if False then " " <> pretty dict else mempty + +prettyBinderHelper :: IRRep r => HoistableE e => Binder r n l -> e l -> Doc ann +prettyBinderHelper (b:>ty) body = + if binderName b `isFreeIn` body + then parens $ pretty (b:>ty) + else pretty ty + +instance Pretty (CorePiType n) where + pretty (CorePiType appExpl expls bs (EffTy eff resultTy)) = + prettyBindersWithExpl expls bs <+> pretty appExpl <> prettyEff <> pretty resultTy + where + prettyEff = case eff of + Pure -> space + _ -> space <> pretty eff <> space + +prettyBindersWithExpl :: forall b n l ann. PrettyB b + => [Explicitness] -> Nest b n l -> Doc ann +prettyBindersWithExpl expls bs = do + let groups = groupByExpl $ zip expls (unsafeFromNest bs) + let groups' = case groups of [] -> [(Explicit, [])] + _ -> groups + mconcat [withExplParens expl $ commaSep bsGroup | (expl, bsGroup) <- groups'] + +groupByExpl :: [(Explicitness, b UnsafeS UnsafeS)] -> [(Explicitness, [b UnsafeS UnsafeS])] +groupByExpl [] = [] +groupByExpl ((expl, b):bs) = do + let (matches, rest) = span (\(expl', _) -> expl == expl') bs + let matches' = map snd matches + (expl, b:matches') : groupByExpl rest + +withExplParens :: Explicitness -> Doc ann -> Doc ann +withExplParens Explicit x = parens x +withExplParens (Inferred _ Unify) x = braces $ x +withExplParens (Inferred _ (Synth _)) x = brackets x + +instance Pretty (RepVal n) where + pretty (RepVal ty tree) = " pretty tree <+> ":" <+> pretty ty <> ">" + +prettyBlock :: (IRRep r, PrettyPrec (e l)) => Nest (Decl r) n l -> e l -> Doc ann +prettyBlock Empty expr = group $ line <> pLowest expr +prettyBlock decls expr = prettyLines decls' <> hardline <> pLowest expr + where decls' = unsafeFromNest decls + +instance IRRep r => Pretty (BaseMonoid r n) where pretty = prettyFromPrettyPrec +instance IRRep r => PrettyPrec (BaseMonoid r n) where + prettyPrec (BaseMonoid x f) = + atPrec LowestPrec $ "baseMonoid" <+> pArg x <> nest 2 (line <> pArg f) diff --git a/src/lib/Types/Imp.hs b/src/lib/Types/Imp.hs index d99d66c4a..9006745ce 100644 --- a/src/lib/Types/Imp.hs +++ b/src/lib/Types/Imp.hs @@ -27,11 +27,16 @@ import qualified Data.ByteString as BS import GHC.Generics (Generic (..)) import Data.Store (Store (..)) +import Data.String (fromString) +import Data.Text.Prettyprint.Doc (line', nest, group) import Name +import PPrint import Util (IsBool (..)) - import Types.Primitives +import Types.Source + +-- === data types === type ImpName = Name ImpNameC @@ -480,3 +485,91 @@ instance Store LinktimeVals instance Hashable IsCUDARequired instance Hashable CallingConvention instance Hashable IFunType + +instance Pretty CallingConvention where pretty = fromString . show + +instance Pretty (ImpFunction n) where + pretty (ImpFunction (IFunType cc _ _) (Abs bs body)) = + "impfun" <+> pretty cc <+> prettyBinderNest bs + <> nest 2 (hardline <> pretty body) <> hardline + +instance Pretty (ImpBlock n) where + pretty = \case + ImpBlock Empty [] -> mempty + ImpBlock Empty expr -> group $ hardline <> pLowest expr + ImpBlock decls [] -> prettyLines $ fromNest decls + ImpBlock decls expr -> prettyLines decls' <> hardline <> pLowest expr + where decls' = fromNest decls + +instance Pretty (IBinder n l) where + pretty (IBinder b ty) = pretty b <+> ":" <+> pretty ty + +instance Pretty (ImpInstr n) where + pretty = \case + IFor a n (Abs i block) -> forStr a <+> p i <+> "<" <+> p n <> + nest 4 (p block) + IWhile body -> "while" <+> nest 2 (p body) + ICond predicate cons alt -> + "if" <+> p predicate <+> "then" <> nest 2 (p cons) <> + hardline <> "else" <> nest 2 (p alt) + IQueryParallelism f s -> "queryParallelism" <+> p f <+> p s + ILaunch f s args -> "launch" <+> p f <+> p s <+> spaced args + ICastOp t x -> "cast" <+> p x <+> "to" <+> p t + IBitcastOp t x -> "bitcast" <+> p x <+> "to" <+> p t + Store dest val -> "store" <+> p dest <+> p val + Alloc _ t s -> "alloc" <+> p t <> "[" <> sizeStr s <> "]" + StackAlloc t s -> "alloca" <+> p t <> "[" <> sizeStr s <> "]" + MemCopy dest src numel -> "memcopy" <+> p dest <+> p src <+> p numel + InitializeZeros ptr numel -> "initializeZeros" <+> p ptr <+> p numel + GetAllocSize ptr -> "getAllocSize" <+> p ptr + Free ptr -> "free" <+> p ptr + ISyncWorkgroup -> "syncWorkgroup" + IThrowError -> "throwError" + ICall f args -> "call" <+> p f <+> p args + IVectorBroadcast v _ -> "vbroadcast" <+> p v + IVectorIota _ -> "viota" + DebugPrint s x -> "debug_print" <+> p (show s) <+> p x + IPtrLoad ptr -> "load" <+> p ptr + IPtrOffset ptr idx -> p ptr <+> "+>" <+> p idx + IBinOp op x y -> opDefault (UBinOp op) [x, y] + IUnOp op x -> opDefault (UUnOp op) [x] + ISelect x y z -> "select" <+> p x <+> p y <+> p z + IOutputStream -> "outputStream" + IShowScalar ptr x -> "show_scalar" <+> p ptr <+> p x + where opDefault name xs = prettyOpDefault name xs $ AppPrec + p :: Pretty a => a -> Doc ann + p = pretty + forStr :: ForAnn -> Doc ann + forStr = \case + Fwd -> "for" + Rev -> "rof" + +sizeStr :: IExpr n -> Doc ann +sizeStr s = case s of + ILit (Word32Lit x) -> pretty x -- print in decimal because it's more readable + _ -> pretty s + +instance Pretty (IExpr n) where + pretty = \case + ILit v -> pretty v + IVar v _ -> pretty v + IPtrVar v _ -> pretty v + +instance PrettyPrec (IExpr n) where prettyPrec = atPrec ArgPrec . pretty + +instance Pretty (ImpDecl n l) where + pretty = \case + ImpLet Empty instr -> pretty instr + ImpLet (Nest b Empty) instr -> pretty b <+> "=" <+> pretty instr + ImpLet bs instr -> pretty bs <+> "=" <+> pretty instr + +instance Pretty IFunType where + pretty (IFunType cc argTys retTys) = + "Fun" <+> pretty cc <+> pretty argTys <+> "->" <+> pretty retTys + +prettyBinderNest :: PrettyB b => Nest b n l -> Doc ann +prettyBinderNest bs = nest 6 $ line' <> (sep $ map pretty $ fromNest bs) + +fromNest :: Nest b n l -> [b UnsafeS UnsafeS] +fromNest Empty = [] +fromNest (Nest b rest) = unsafeCoerceB b : fromNest rest diff --git a/src/lib/Types/Misc.hs b/src/lib/Types/Misc.hs deleted file mode 100644 index 71416eead..000000000 --- a/src/lib/Types/Misc.hs +++ /dev/null @@ -1,36 +0,0 @@ --- Copyright 2022 Google LLC --- --- Use of this source code is governed by a BSD-style --- license that can be found in the LICENSE file or at --- https://developers.google.com/open-source/licenses/bsd - -module Types.Misc where - -import GHC.Generics (Generic (..)) - -import Err -import Logging -import Types.Source - -type LitProg = [(SourceBlock, Result)] - -data Result = Result - { resultOutputs :: [Output] - , resultErrs :: Except () } - deriving (Show, Eq) - -type BenchStats = (Int, Double) -- number of runs, total benchmarking time -data Output = - TextOut String - | HtmlOut String - | PassInfo PassName String - | EvalTime Double (Maybe BenchStats) - | TotalTime Double - | BenchResult String Double Double (Maybe BenchStats) -- name, compile time, eval time - | MiscLog String - -- Used to have | ExportedFun String Atom - deriving (Show, Eq, Generic) - -type PassLogger = FilteredLogger PassName [Output] - -data OptLevel = NoOptimize | Optimize diff --git a/src/lib/Types/OpNames.hs b/src/lib/Types/OpNames.hs index 178936ec7..344329ac6 100644 --- a/src/lib/Types/OpNames.hs +++ b/src/lib/Types/OpNames.hs @@ -14,6 +14,8 @@ import Data.Hashable import GHC.Generics (Generic (..)) import Data.Store (Store (..)) +import PPrint + data TC = ProdType | SumType | RefType | TypeKind | HeapType data Con = ProdCon | SumCon Int | HeapVal @@ -117,3 +119,8 @@ deriving instance Eq (Hof r) deriving instance Eq DAMOp deriving instance Eq RefOp deriving instance Eq UserEffectOp + +instance Pretty Projection where + pretty = \case + UnwrapNewtype -> "u" + ProjectProduct i -> pretty i diff --git a/src/lib/Types/Primitives.hs b/src/lib/Types/Primitives.hs index 002a6d09a..f449acba6 100644 --- a/src/lib/Types/Primitives.hs +++ b/src/lib/Types/Primitives.hs @@ -24,18 +24,23 @@ module Types.Primitives ( import qualified Data.ByteString as BS import Data.Int +import Data.String (IsString (..)) import Data.Word import Data.Hashable import Data.Store (Store (..)) import qualified Data.Store.Internal as SI import Foreign.Ptr +import Numeric +import GHC.Float import GHC.Generics (Generic (..)) +import PPrint import Occurrence import Types.OpNames (UnOp (..), BinOp (..), CmpOp (..), Projection (..)) +import Name -type SourceName = String +newtype SourceName = MkSourceName String deriving (Show, Eq, Ord, Generic) newtype AlwaysEqual a = AlwaysEqual a deriving (Show, Generic, Functor, Foldable, Traversable, Hashable, Store) @@ -60,8 +65,11 @@ data RequiredMethodAccess = Full | Partial Int deriving (Show, Eq, Ord, Generic) data LetAnn = -- Binding with no additional information PlainLet + -- Binding explicitly tagged "inline immediately" + | InlineLet -- Binding explicitly tagged "do not inline" | NoInlineLet + | LinearLet -- Bound expression is pure, and the binding's occurrences are summarized by -- the UsageInfo | OccInfoPure UsageInfo @@ -178,6 +186,16 @@ emptyLit = \case -- === Typeclass instances === +instance HasNameHint SourceName where + getNameHint (MkSourceName v) = getNameHint v + +instance Pretty SourceName where + pretty (MkSourceName v) = pretty v + +instance IsString SourceName where + fromString v = MkSourceName v + +instance Store SourceName instance Store RequiredMethodAccess instance Store LetAnn instance Store RWS @@ -191,6 +209,7 @@ instance Store AppExplicitness instance Store DepPairExplicitness instance Store InferenceMechanism +instance Hashable SourceName instance Hashable RWS instance Hashable Direction instance Hashable BaseType @@ -205,3 +224,75 @@ instance Hashable AppExplicitness instance Hashable DepPairExplicitness instance Hashable InferenceMechanism instance Hashable RequiredMethodAccess + +-- === Pretty instances === + +instance Pretty AppExplicitness where + pretty ExplicitApp = "->" + pretty ImplicitApp = "->>" + +instance Pretty RWS where + pretty eff = case eff of + Reader -> "Read" + Writer -> "Accum" + State -> "State" + +instance Pretty LetAnn where + pretty ann = case ann of + PlainLet -> "" + InlineLet -> "%inline" + NoInlineLet -> "%noinline" + LinearLet -> "%linear" + OccInfoPure u -> pretty u <> hardline + OccInfoImpure u -> pretty u <> ", impure" <> hardline + +instance PrettyPrec Direction where + prettyPrec d = atPrec ArgPrec $ case d of + Fwd -> "fwd" + Rev -> "rev" + +printDouble :: Double -> Doc ann +printDouble x = pretty (double2Float x) + +printFloat :: Float -> Doc ann +printFloat x = pretty $ reverse $ dropWhile (=='0') $ reverse $ + showFFloat (Just 6) x "" + +instance Pretty LitVal where pretty = prettyFromPrettyPrec +instance PrettyPrec LitVal where + prettyPrec = \case + Int64Lit x -> atPrec ArgPrec $ p x + Int32Lit x -> atPrec ArgPrec $ p x + Float64Lit x -> atPrec ArgPrec $ printDouble x + Float32Lit x -> atPrec ArgPrec $ printFloat x + Word8Lit x -> atPrec ArgPrec $ p $ show $ toEnum @Char $ fromIntegral x + Word32Lit x -> atPrec ArgPrec $ p $ "0x" ++ showHex x "" + Word64Lit x -> atPrec ArgPrec $ p $ "0x" ++ showHex x "" + PtrLit ty (PtrLitVal x) -> atPrec ArgPrec $ "Ptr" <+> p ty <+> p (show x) + PtrLit _ NullPtr -> atPrec ArgPrec $ "NullPtr" + PtrLit _ (PtrSnapshot _) -> atPrec ArgPrec "" + where p :: Pretty a => a -> Doc ann + p = pretty + +instance Pretty Device where pretty = fromString . show + +instance Pretty BaseType where pretty = prettyFromPrettyPrec +instance PrettyPrec BaseType where + prettyPrec b = case b of + Scalar sb -> prettyPrec sb + Vector shape sb -> atPrec ArgPrec $ encloseSep "<" ">" "x" $ (pretty <$> shape) ++ [pretty sb] + PtrType ty -> atPrec AppPrec $ "Ptr" <+> pretty ty + +instance Pretty ScalarBaseType where pretty = prettyFromPrettyPrec +instance PrettyPrec ScalarBaseType where + prettyPrec sb = atPrec ArgPrec $ case sb of + Int64Type -> "Int64" + Int32Type -> "Int32" + Float64Type -> "Float64" + Float32Type -> "Float32" + Word8Type -> "Word8" + Word32Type -> "Word32" + Word64Type -> "Word64" + +instance Pretty Explicitness where + pretty expl = pretty (show expl) diff --git a/src/lib/Types/Source.hs b/src/lib/Types/Source.hs index 21a6974ee..b20a0e09a 100644 --- a/src/lib/Types/Source.hs +++ b/src/lib/Types/Source.hs @@ -20,42 +20,38 @@ module Types.Source where -import Data.Data +import Data.Aeson (ToJSON) import Data.Hashable import Data.Foldable import qualified Data.Map.Strict as M import qualified Data.Set as S -import Data.String (IsString, fromString) import Data.Text (Text) -import Data.Text.Prettyprint.Doc (Pretty (..), hardline, (<+>)) import Data.Word +import Data.Text.Prettyprint.Doc (vcat, line, group, parens, nest, align, punctuate, hsep) +import Data.Text (snoc, unsnoc) +import Data.Tuple (swap) import GHC.Generics (Generic (..)) import Data.Store (Store (..)) +import Data.String (fromString) +import Err +import PPrint import Name import qualified Types.OpNames as P import IRVariants -import SourceInfo -import Util (File (..)) +import MonadUtil +import Util (File (..), SnocList) +import IncState import Types.Primitives -data SourceName' = SourceName' SrcPosCtx SourceName - deriving (Show, Eq, Ord, Generic) - -fromName :: SourceName -> SourceName' -fromName = SourceName' emptySrcPosCtx - -instance HasNameHint SourceName' where - getNameHint (SourceName' _ name) = getNameHint name - data SourceNameOr (a::E) (n::S) where -- Only appears before renaming pass - SourceName :: SrcPosCtx -> SourceName -> SourceNameOr a n + SourceName :: SrcId -> SourceName -> SourceNameOr a n -- Only appears after renaming pass -- We maintain the source name for user-facing error messages. - InternalName :: SrcPosCtx -> SourceName -> a n -> SourceNameOr a n + InternalName :: SrcId -> SourceName -> a n -> SourceNameOr a n deriving instance Eq (a n) => Eq (SourceNameOr a n) deriving instance Ord (a n) => Ord (SourceNameOr a n) deriving instance Show (a n) => Show (SourceNameOr a n) @@ -63,109 +59,185 @@ deriving instance Show (a n) => Show (SourceNameOr a n) newtype SourceOrInternalName (c::C) (n::S) = SourceOrInternalName (SourceNameOr (Name c) n) deriving (Eq, Ord, Show, Generic) -pattern SISourceName :: (n ~ VoidS) => SourceName -> SourceOrInternalName c n -pattern SISourceName n = SourceOrInternalName (SourceName EmptySrcPosCtx n) +-- === Source Info === + +-- This is just for syntax highlighting. It won't be needed if we have +-- a separate lexing pass where we have a complete lossless data type for +-- lexemes. +data LexemeType = + Keyword + | Symbol + | TypeName + | LowerName + | UpperName + | LiteralLexeme + | StringLiteralLexeme + | MiscLexeme + deriving (Show, Generic) + +type Span = (Int, Int) +data LexemeInfo = LexemeInfo + { lexemeList :: SnocList SrcId + , lexemeInfo :: M.Map SrcId (LexemeType, Span) } + deriving (Show, Generic) + +type LexemeId = SrcId +type LexemeSpan = (LexemeId, LexemeId) +data GroupTree = GroupTree + { gtSrcId :: SrcId + , gtSpan :: LexemeSpan + , gtChildren :: [GroupTree] + , gtIsAtomicLexeme :: Bool } + deriving (Show, Eq, Generic) + +instance Semigroup LexemeInfo where + LexemeInfo a b <> LexemeInfo a' b' = LexemeInfo (a <> a') (b <> b') +instance Monoid LexemeInfo where + mempty = LexemeInfo mempty mempty + +-- === Type Info === + +newtype TypeInfo = TypeInfo { fromTypeInfo :: M.Map SrcId String } + deriving (Semigroup, Monoid, ToJSON, Show, Eq) + +-- === Results === + +type LitProg = [(SourceBlock, Result)] -pattern SIInternalName :: SourceName -> Name c n -> Maybe SrcPos -> Maybe SpanId -> SourceOrInternalName c n -pattern SIInternalName n a srcPos spanId = SourceOrInternalName (InternalName (SrcPosCtx srcPos spanId) n a) +data Result = Result + { resultOutputs :: Outputs + , resultErrs :: Except () } + deriving (Show, Eq) + +type BenchStats = (Int, Double) -- number of runs, total benchmarking time + +data SourceInfo = + SIGroupTree (Overwrite GroupTree) + | SITypeInfo TypeInfo + deriving (Show, Eq, Generic) + +data Output = + TextOut String + | HtmlOut String + | SourceInfo SourceInfo -- for hovertips etc + | PassInfo PassName String + | MiscLog String + | Error Err + deriving (Show, Eq, Generic) +newtype Outputs = Outputs { fromOutputs :: [Output] } + deriving (Show, Eq, Generic, Semigroup, Monoid) + +type PassLogger = IOLogger Outputs + +data OptLevel = NoOptimize | Optimize + +instance Semigroup Result where + Result outs err <> Result outs' err' = Result (outs <> outs') err'' + where err'' = case err' of + Success () -> err + Failure _ -> err' + +instance Monoid Result where + mempty = Result mempty (Success ()) -- === Concrete syntax === -- The grouping-level syntax of the source language +-- aliases for the "with source ID versions" + +type GroupW = WithSrcs Group +type CTopDeclW = WithSrcs CTopDecl +type CSDeclW = WithSrcs CSDecl +type SourceNameW = WithSrc SourceName +type BinW = WithSrc Bin + +type BracketedGroup = WithSrcs [GroupW] -- optional arrow, effects, result type -type ExplicitParams = [Group] -type GivenClause = ([Group], Maybe [Group]) -- implicits, classes -type WithClause = [Group] -- no classes because we don't want to carry class dicts at runtime +type ExplicitParams = BracketedGroup +type GivenClause = (BracketedGroup, Maybe BracketedGroup) -- implicits, classes +type WithClause = BracketedGroup -- no classes because we don't want to carry class dicts at runtime -type CTopDecl = WithSrc CTopDecl' -data CTopDecl' - = CSDecl LetAnn CSDecl' +data CTopDecl + = CSDecl LetAnn CSDecl | CData - SourceName -- Type constructor name - ExplicitParams + SourceNameW -- Type constructor name + (Maybe ExplicitParams) (Maybe GivenClause) - [(SourceName, ExplicitParams)] -- Constructor names and argument sets + [(SourceNameW, Maybe ExplicitParams)] -- Constructor names and argument sets | CStruct - SourceName -- Type constructor name - ExplicitParams + SourceNameW -- Type constructor name + (Maybe ExplicitParams) (Maybe GivenClause) - [(SourceName, Group)] -- Field names and types + [(SourceNameW, GroupW)] -- Field names and types [(LetAnn, CDef)] | CInterface - SourceName -- Interface name + SourceNameW -- Interface name ExplicitParams - [(SourceName, Group)] -- Method declarations - | CEffectDecl SourceName [(SourceName, UResumePolicy, Group)] - | CHandlerDecl SourceName -- Handler name - SourceName -- Effect name - SourceName -- Body type parameter - Group -- Handler arguments - Group -- Handler type annotation - [(SourceName, Maybe UResumePolicy, CSBlock)] -- Handler methods + [(SourceNameW, GroupW)] -- Method declarations -- header, givens (may be empty), methods, optional name. The header should contain -- the prerequisites, class name, and class arguments. | CInstanceDecl CInstanceDef deriving (Show, Generic) -type CSDecl = WithSrc CSDecl' -data CSDecl' - = CLet Group CSBlock +data CSDecl + = CLet GroupW CSBlock | CDefDecl CDef - | CExpr Group - | CBind Group CSBlock -- Arrow binder <- + | CExpr GroupW + | CBind GroupW CSBlock -- Arrow binder <- | CPass deriving (Show, Generic) -type CEffs = ([Group], Maybe Group) +type CEffs = WithSrcs ([GroupW], Maybe GroupW) data CDef = CDef - SourceName - (ExplicitParams) + SourceNameW + ExplicitParams (Maybe CDefRhs) (Maybe GivenClause) CSBlock deriving (Show, Generic) -type CDefRhs = (AppExplicitness, Maybe CEffs, Group) +type CDefRhs = (AppExplicitness, Maybe CEffs, GroupW) data CInstanceDef = CInstanceDef - SourceName -- interface name - [Group] -- args at which we're instantiating the interface + SourceNameW -- interface name + [GroupW] -- args at which we're instantiating the interface (Maybe GivenClause) - [CSDecl] -- Method definitions - (Maybe (SourceName, Maybe [Group])) -- Optional name of instance, with explicit parameters + [CSDeclW] -- Method definitions + (Maybe (SourceNameW, Maybe BracketedGroup)) -- Optional name of instance, with explicit parameters deriving (Show, Generic) -type Group = WithSrc Group' -data Group' - = CEmpty - | CIdentifier SourceName - | CPrim PrimName [Group] +data Group + = CLeaf CLeaf + | CPrim PrimName [GroupW] + | CParens [GroupW] + | CBrackets [GroupW] + | CBin BinW GroupW GroupW + | CJuxtapose Bool GroupW GroupW -- Bool means "there's a space between the groups" + | CPrefix SourceNameW GroupW -- covers unary - and unary + among others + | CGivens GivenClause + | CLambda [GroupW] CSBlock + | CFor ForKind [GroupW] CSBlock -- also for_, rof, rof_ + | CCase GroupW [CaseAlt] -- scrutinee, alternatives + | CIf GroupW CSBlock (Maybe CSBlock) + | CDo CSBlock + | CArrow GroupW (Maybe CEffs) GroupW + | CWith GroupW WithClause + deriving (Show, Generic) + +data CLeaf + = CIdentifier SourceName | CNat Word64 | CInt Int | CString String | CChar Char | CFloat Double | CHole - | CParens [Group] - | CBrackets [Group] - | CBin Bin Group Group - | CPrefix SourceName Group -- covers unary - and unary + among others - | CPostfix SourceName Group - | CLambda [Group] CSBlock - | CFor ForKind [Group] CSBlock -- also for_, rof, rof_ - | CCase Group [(Group, CSBlock)] -- scrutinee, alternatives - | CIf Group CSBlock (Maybe CSBlock) - | CDo CSBlock - | CGivens GivenClause - | CArrow Group (Maybe CEffs) Group - | CWith Group WithClause deriving (Show, Generic) -type Bin = WithSrc Bin' -data Bin' - = JuxtaposeWithSpace - | JuxtaposeNoSpace - | EvalBinOp String +type CaseAlt = (GroupW, CSBlock) -- scrutinee, lexeme Id, body + +data Bin + = EvalBinOp SourceName | DepAmpersand | Dot | DepComma @@ -176,7 +248,7 @@ data Bin' | FatArrow -- => | Pipe | CSEqual - deriving (Eq, Ord, Show, Generic) + deriving (Show, Generic) data LabelPrefix = PlainLabel deriving (Show, Generic) @@ -190,8 +262,8 @@ data ForKind -- `CSBlock` instead of `CBlock` because the latter is an alias for `Block CoreIR`. data CSBlock = - IndentedBlock [CSDecl] -- last decl should be a CExpr - | ExprBlock Group + IndentedBlock SrcId [CSDeclW] -- last decl should be a CExpr + | ExprBlock GroupW deriving (Show, Generic) -- === Untyped IR === @@ -216,22 +288,21 @@ data UVar (n::S) = | UTyConVar (Name TyConNameC n) | UDataConVar (Name DataConNameC n) | UClassVar (Name ClassNameC n) - | UEffectVar (Name EffectNameC n) | UMethodVar (Name MethodNameC n) - | UEffectOpVar (Name EffectOpNameC n) | UPunVar (Name TyConNameC n) -- for names also used as data constructors deriving (Eq, Ord, Show, Generic) type UAtomBinder = UBinder (AtomNameC CoreIR) -data UBinder (c::C) (n::S) (l::S) where +type UBinder c = WithSrcB (UBinder' c) +data UBinder' (c::C) (n::S) (l::S) where -- Only appears before renaming pass - UBindSource :: SrcPosCtx -> SourceName -> UBinder c n n + UBindSource :: SourceName -> UBinder' c n n -- May appear before or after renaming pass - UIgnore :: UBinder c n n + UIgnore :: UBinder' c n n -- The following binders only appear after the renaming pass. -- We maintain the source name for user-facing error messages -- and named arguments. - UBind :: SrcPosCtx -> SourceName -> NameBinder c n l -> UBinder c n l + UBind :: SourceName -> NameBinder c n l -> UBinder' c n l type UBlock = WithSrcE UBlock' data UBlock' (n::S) where @@ -274,12 +345,9 @@ data FieldName' = | FieldNum Int deriving (Show, Eq, Ord) -type UAnnExplBinders req n l = ([Explicitness], Nest (UAnnBinder req) n l) -type UOptAnnExplBinders n l = UAnnExplBinders AnnOptional n l - data ULamExpr (n::S) where ULamExpr - :: UOptAnnExplBinders n l -- args + :: Nest UAnnBinder n l -- args -> AppExplicitness -> Maybe (UEffectRow l) -- optional effect -> Maybe (UType l) -- optional result type @@ -287,33 +355,33 @@ data ULamExpr (n::S) where -> ULamExpr n data UPiExpr (n::S) where - UPiExpr :: UOptAnnExplBinders n l -> AppExplicitness -> UEffectRow l -> UType l -> UPiExpr n + UPiExpr :: Nest UAnnBinder n l -> AppExplicitness -> UEffectRow l -> UType l -> UPiExpr n data UTabPiExpr (n::S) where - UTabPiExpr :: UOptAnnBinder n l -> UType l -> UTabPiExpr n + UTabPiExpr :: UAnnBinder n l -> UType l -> UTabPiExpr n data UDepPairType (n::S) where - UDepPairType :: DepPairExplicitness -> UOptAnnBinder n l -> UType l -> UDepPairType n + UDepPairType :: DepPairExplicitness -> UAnnBinder n l -> UType l -> UDepPairType n -type UConDef (n::S) (l::S) = (SourceName, Nest UReqAnnBinder n l) +type UConDef (n::S) (l::S) = (SourceName, Nest UAnnBinder n l) data UDataDef (n::S) where UDataDef :: SourceName -- source name for pretty printing - -> UOptAnnExplBinders n l + -> Nest UAnnBinder n l -> [(SourceName, UDataDefTrail l)] -- data constructor types -> UDataDef n data UStructDef (n::S) where UStructDef :: SourceName -- source name for pretty printing - -> UOptAnnExplBinders n l - -> [(SourceName, UType l)] -- named payloads + -> Nest UAnnBinder n l + -> [(SourceNameW, UType l)] -- named payloads -> [(LetAnn, SourceName, Abs UAtomBinder ULamExpr l)] -- named methods (initial binder is for `self`) -> UStructDef n data UDataDefTrail (l::S) where - UDataDefTrail :: Nest UReqAnnBinder l l' -> UDataDefTrail l + UDataDefTrail :: Nest UAnnBinder l l' -> UDataDefTrail l data UTopDecl (n::S) (l::S) where ULocalDecl :: UDecl n l -> UTopDecl n l @@ -327,42 +395,24 @@ data UTopDecl (n::S) (l::S) where -> UStructDef l -- actual definition -> UTopDecl n l UInterface - :: UOptAnnExplBinders n p -- parameter binders + :: Nest UAnnBinder n p -- parameter binders -> [UType p] -- method types -> UBinder ClassNameC n l' -- class name -> Nest (UBinder MethodNameC) l' l -- method names -> UTopDecl n l UInstance :: SourceNameOr (Name ClassNameC) n -- class name - -> UOptAnnExplBinders n l' + -> Nest UAnnBinder n l' -> [UExpr l'] -- class parameters -> [UMethodDef l'] -- method definitions -- Maybe we should make a separate color (namespace) for instance names? -> MaybeB UAtomBinder n l -- optional instance name -> AppExplicitness -- explicitness (only relevant for named instances) -> UTopDecl n l - UEffectDecl - :: [UEffectOpType n] -- operation types - -> UBinder EffectNameC n l' -- effect name - -> Nest (UBinder EffectOpNameC) l' l -- operation names - -> UTopDecl n l - UHandlerDecl - :: SourceNameOr (Name EffectNameC) n -- effect name - -> UAtomBinder n b -- body type argument - -> UOptAnnExplBinders b l' -- type args - -> UEffectRow l' -- returning effect - -> UType l' -- returning type - -> [UEffectOpDef l'] -- operation definitions - -> UBinder HandlerNameC n l -- handler name - -> UTopDecl n l type UType = UExpr type UConstraint = UExpr -data UEffectOpType (n::S) where - UEffectOpType :: UResumePolicy -> UType s -> UEffectOpType s - deriving (Show, Generic) - data UResumePolicy = UNoResume | ULinearResume @@ -373,32 +423,19 @@ instance Hashable UResumePolicy instance Store UResumePolicy data UForExpr (n::S) where - UForExpr :: UOptAnnBinder n l -> UBlock l -> UForExpr n + UForExpr :: UAnnBinder n l -> UBlock l -> UForExpr n type UMethodDef = WithSrcE UMethodDef' data UMethodDef' (n::S) = UMethodDef (SourceNameOr (Name MethodNameC) n) (ULamExpr n) deriving (Show, Generic) -data UEffectOpDef (n::S) = - UEffectOpDef UResumePolicy (SourceNameOr (Name EffectOpNameC) n) (UExpr n) - | UReturnOpDef (UExpr n) - deriving (Show, Generic) - -data AnnRequirement = AnnRequired | AnnOptional - -data UAnn (annReq::AnnRequirement) (n::S) where - UAnn :: UType n -> UAnn annReq n - UNoAnn :: UAnn AnnOptional n -deriving instance Show (UAnn annReq n) - +data UAnn (n::S) = UAnn (UType n) | UNoAnn deriving Show -data UAnnBinder (annReq::AnnRequirement) (n::S) (l::S) = - UAnnBinder (UAtomBinder n l) (UAnn annReq n) [UConstraint n] +-- TODO: SrcId +data UAnnBinder (n::S) (l::S) = + UAnnBinder Explicitness (UAtomBinder n l) (UAnn n) [UConstraint n] deriving (Show, Generic) -type UReqAnnBinder = UAnnBinder AnnRequired :: B -type UOptAnnBinder = UAnnBinder AnnOptional :: B - data UAlt (n::S) where UAlt :: UPat n l -> UBlock l -> UAlt n @@ -411,42 +448,92 @@ data UPat' (n::S) (l::S) = | UPatTable (Nest UPat n l) deriving (Show, Generic) -pattern UPatIgnore :: UPat' (n::S) n -pattern UPatIgnore = UPatBinder UIgnore - -- === source names for error messages === class HasSourceName a where getSourceName :: a -> SourceName -instance HasSourceName (UAnnBinder req n l) where - getSourceName (UAnnBinder b _ _) = getSourceName b +instance HasSourceName (b n l) => HasSourceName (WithSrcB b n l) where + getSourceName (WithSrcB _ b) = getSourceName b + +instance HasSourceName (UAnnBinder n l) where + getSourceName (UAnnBinder _ b _ _) = getSourceName b -instance HasSourceName (UBinder c n l) where +instance HasSourceName (UBinder' c n l) where getSourceName = \case - UBindSource _ sn -> sn - UIgnore -> "_" - UBind _ sn _ -> sn + UBindSource sn -> sn + UIgnore -> "_" + UBind sn _ -> sn -- === Source context helpers === -data WithSrc a = WithSrc SrcPosCtx a +-- First SrcId is for the group itself. The rest are for keywords, symbols, etc. +data WithSrcs a = WithSrcs SrcId [SrcId] a + deriving (Show, Functor, Generic) + +data WithSrc a = WithSrc SrcId a deriving (Show, Functor, Generic) -data WithSrcE (a::E) (n::S) = WithSrcE SrcPosCtx (a n) +data WithSrcE (a::E) (n::S) = WithSrcE SrcId (a n) deriving (Show, Generic) -data WithSrcB (binder::B) (n::S) (l::S) = WithSrcB SrcPosCtx (binder n l) - deriving (Show, Data, Generic) +data WithSrcB (binder::B) (n::S) (l::S) = WithSrcB SrcId (binder n l) + deriving (Show, Generic) + +instance HasSrcId (WithSrc a ) where getSrcId (WithSrc sid _ ) = sid +instance HasSrcId (WithSrcs a ) where getSrcId (WithSrcs sid _ _) = sid +instance HasSrcId (WithSrcE e n ) where getSrcId (WithSrcE sid _ ) = sid +instance HasSrcId (WithSrcB b n l) where getSrcId (WithSrcB sid _ ) = sid + +instance HasSrcId (UAnnBinder n l) where + getSrcId (UAnnBinder _ b _ _) = getSrcId b + +class HasSrcPos withSrc a | withSrc -> a where + srcPos :: withSrc -> SrcId + withoutSrc :: withSrc -> a + +instance HasSrcPos (WithSrc (a:: *)) a where + srcPos (WithSrc pos _) = pos + withoutSrc (WithSrc _ x) = x -class HasSrcPos a where - srcPos :: a -> SrcPosCtx +instance HasSrcPos (WithSrcs (a:: *)) a where + srcPos (WithSrcs pos _ _) = pos + withoutSrc (WithSrcs _ _ x) = x -instance HasSrcPos (WithSrcE (a::E) (n::S)) where +instance HasSrcPos (WithSrcE (e::E) (n::S)) (e n) where srcPos (WithSrcE pos _) = pos + withoutSrc (WithSrcE _ x) = x -instance HasSrcPos (WithSrcB (b::B) (n::S) (n::S)) where +instance HasSrcPos (WithSrcB (b::B) (n::S) (l::S)) (b n l) where srcPos (WithSrcB pos _) = pos + withoutSrc (WithSrcB _ x) = x + +class FromSourceNameW a where + fromSourceNameW :: SourceNameW -> a + +instance FromSourceNameW (SourceNameOr a VoidS) where + fromSourceNameW (WithSrc sid x) = SourceName sid x + +instance FromSourceNameW (SourceOrInternalName c VoidS) where + fromSourceNameW x = SourceOrInternalName $ fromSourceNameW x + +instance FromSourceNameW (UBinder' s VoidS VoidS) where + fromSourceNameW x = UBindSource $ withoutSrc x + +instance FromSourceNameW (UPat' VoidS VoidS) where + fromSourceNameW = UPatBinder . fromSourceNameW + +instance FromSourceNameW (UAnnBinder VoidS VoidS) where + fromSourceNameW s = UAnnBinder Explicit (fromSourceNameW s) UNoAnn [] + +instance FromSourceNameW (UExpr' VoidS) where + fromSourceNameW = UVar . fromSourceNameW + +instance FromSourceNameW (a n) => FromSourceNameW (WithSrcE a n) where + fromSourceNameW x = WithSrcE (srcPos x) $ fromSourceNameW x + +instance FromSourceNameW (b n l) => FromSourceNameW (WithSrcB b n l) where + fromSourceNameW x = WithSrcB (srcPos x) $ fromSourceNameW x -- === SourceMap === @@ -481,11 +568,11 @@ data UModule = UModule -- === top-level blocks === data SourceBlock = SourceBlock - { sbLine :: Int - , sbOffset :: Int - , sbLogLevel :: LogLevel - , sbText :: Text - , sbContents :: SourceBlock' } + { sbLine :: Int + , sbOffset :: Int + , sbText :: Text + , sbLexemeInfo :: LexemeInfo + , sbContents :: SourceBlock' } deriving (Show, Generic) type ReachedEOF = Bool @@ -494,16 +581,16 @@ data SymbolicZeros = SymbolicZeros | InstantiateZeros deriving (Generic, Eq, Show) data SourceBlock' - = TopDecl CTopDecl - | Command CmdName Group - | DeclareForeign SourceName SourceName Group - | DeclareCustomLinearization SourceName SymbolicZeros Group + = TopDecl CTopDeclW + | Command CmdName GroupW + | DeclareForeign SourceNameW SourceNameW GroupW + | DeclareCustomLinearization SourceNameW SymbolicZeros GroupW | Misc SourceBlockMisc | UnParseable ReachedEOF String deriving (Show, Generic) data SourceBlockMisc - = GetNameType SourceName + = GetNameType SourceNameW | ImportModule ModuleSourceName | QueryEnv EnvQuery | ProseBlock Text @@ -514,10 +601,6 @@ data SourceBlockMisc data CmdName = GetType | EvalExpr OutFormat | ExportFun String deriving (Show, Generic) -data LogLevel = LogNothing | PrintEvalTime | PrintBench String - | LogPasses [PassName] | LogAll - deriving (Show, Generic) - data PrintBackend = PrintCodegen -- Soon-to-be default path based on `PrintAny` | PrintHaskell -- Backup path for debugging in case the codegen path breaks. @@ -529,7 +612,7 @@ data PrintBackend = data OutFormat = Printed (Maybe PrintBackend) | RenderHtml deriving (Show, Eq, Generic) -data PassName = Parse | RenamePass | TypePass | SynthPass | SimpPass | ImpPass | JitPass +data PassName = Parse | RenamePass | TypePass | SimpPass | ImpPass | JitPass | LLVMOpt | AsmPass | JAXPass | JAXSimpPass | LLVMEval | LowerOptPass | LowerPass | ResultPass | JaxprAndHLO | EarlyOptPass | OptPass | VectPass | OccAnalysisPass | InlinePass @@ -537,8 +620,7 @@ data PassName = Parse | RenamePass | TypePass | SynthPass | SimpPass | ImpPass | instance Show PassName where show p = case p of - Parse -> "parse" ; RenamePass -> "rename"; - TypePass -> "typed" ; SynthPass -> "synth" + Parse -> "parse" ; RenamePass -> "rename"; TypePass -> "typed" SimpPass -> "simp" ; ImpPass -> "imp" ; JitPass -> "llvm" LLVMOpt -> "llvmopt" ; AsmPass -> "asm" JAXPass -> "jax" ; JAXSimpPass -> "jsimp"; ResultPass -> "result" @@ -573,6 +655,102 @@ data PrimName = | UTuple -- overloaded for type constructor and data constructor, resolved in inference deriving (Show, Eq, Generic) +-- === primitive constructors and operators === + +strToPrimName :: String -> Maybe PrimName +strToPrimName s = M.lookup s primNames + +primNameToStr :: PrimName -> String +primNameToStr prim = case lookup prim $ map swap $ M.toList primNames of + Just s -> s + Nothing -> show prim + +showPrimName :: PrimName -> String +showPrimName prim = primNameToStr prim +{-# NOINLINE showPrimName #-} + +primNames :: M.Map String PrimName +primNames = M.fromList + [ ("ask" , UMAsk), ("mextend", UMExtend) + , ("get" , UMGet), ("put" , UMPut) + , ("while" , UWhile) + , ("linearize", ULinearize), ("linearTranspose", UTranspose) + , ("runReader", URunReader), ("runWriter" , URunWriter), ("runState", URunState) + , ("runIO" , URunIO ), ("catchException" , UCatchException) + , ("iadd" , binary IAdd), ("isub" , binary ISub) + , ("imul" , binary IMul), ("fdiv" , binary FDiv) + , ("fadd" , binary FAdd), ("fsub" , binary FSub) + , ("fmul" , binary FMul), ("idiv" , binary IDiv) + , ("irem" , binary IRem) + , ("fpow" , binary FPow) + , ("and" , binary BAnd), ("or" , binary BOr ) + , ("not" , unary BNot), ("xor" , binary BXor) + , ("shl" , binary BShL), ("shr" , binary BShR) + , ("ieq" , binary (ICmp Equal)), ("feq", binary (FCmp Equal)) + , ("igt" , binary (ICmp Greater)), ("fgt", binary (FCmp Greater)) + , ("ilt" , binary (ICmp Less)), ("flt", binary (FCmp Less)) + , ("fneg" , unary FNeg) + , ("exp" , unary Exp), ("exp2" , unary Exp2) + , ("log" , unary Log), ("log2" , unary Log2), ("log10" , unary Log10) + , ("sin" , unary Sin), ("cos" , unary Cos) + , ("tan" , unary Tan), ("sqrt" , unary Sqrt) + , ("floor", unary Floor), ("ceil" , unary Ceil), ("round", unary Round) + , ("log1p", unary Log1p), ("lgamma", unary LGamma) + , ("erf" , unary Erf), ("erfc" , unary Erfc) + , ("TyKind" , UPrimTC $ P.TypeKind) + , ("Float64" , baseTy $ Scalar Float64Type) + , ("Float32" , baseTy $ Scalar Float32Type) + , ("Int64" , baseTy $ Scalar Int64Type) + , ("Int32" , baseTy $ Scalar Int32Type) + , ("Word8" , baseTy $ Scalar Word8Type) + , ("Word32" , baseTy $ Scalar Word32Type) + , ("Word64" , baseTy $ Scalar Word64Type) + , ("Int32Ptr" , baseTy $ ptrTy $ Scalar Int32Type) + , ("Word8Ptr" , baseTy $ ptrTy $ Scalar Word8Type) + , ("Word32Ptr" , baseTy $ ptrTy $ Scalar Word32Type) + , ("Word64Ptr" , baseTy $ ptrTy $ Scalar Word64Type) + , ("Float32Ptr", baseTy $ ptrTy $ Scalar Float32Type) + , ("PtrPtr" , baseTy $ ptrTy $ ptrTy $ Scalar Word8Type) + , ("Nat" , UNat) + , ("Fin" , UFin) + , ("EffKind" , UEffectRowKind) + , ("NatCon" , UNatCon) + , ("Ref" , UPrimTC $ P.RefType) + , ("HeapType" , UPrimTC $ P.HeapType) + , ("indexRef" , UIndexRef) + , ("alloc" , memOp $ P.IOAlloc) + , ("free" , memOp $ P.IOFree) + , ("ptrOffset", memOp $ P.PtrOffset) + , ("ptrLoad" , memOp $ P.PtrLoad) + , ("ptrStore" , memOp $ P.PtrStore) + , ("throwError" , miscOp $ P.ThrowError) + , ("throwException", miscOp $ P.ThrowException) + , ("dataConTag" , miscOp $ P.SumTag) + , ("toEnum" , miscOp $ P.ToEnum) + , ("outputStream" , miscOp $ P.OutputStream) + , ("cast" , miscOp $ P.CastOp) + , ("bitcast" , miscOp $ P.BitcastOp) + , ("unsafeCoerce" , miscOp $ P.UnsafeCoerce) + , ("garbageVal" , miscOp $ P.GarbageVal) + , ("select" , miscOp $ P.Select) + , ("showAny" , miscOp $ P.ShowAny) + , ("showScalar" , miscOp $ P.ShowScalar) + , ("projNewtype" , UProjNewtype) + , ("applyMethod0" , UApplyMethod 0) + , ("applyMethod1" , UApplyMethod 1) + , ("applyMethod2" , UApplyMethod 2) + , ("explicitApply", UExplicitApply) + , ("monoLit", UMonoLiteral) + ] + where + binary op = UBinOp op + baseTy b = UBaseType b + memOp op = UMemOp op + unary op = UUnOp op + ptrTy ty = PtrType (CPU, ty) + miscOp op = UMiscOp op + + -- === instances === instance Semigroup (SourceMap n) where @@ -627,19 +805,16 @@ instance Pretty (SourceMap n) where fold [pretty v <+> "@>" <+> pretty x <> hardline | (v, x) <- M.toList m ] instance GenericE UVar where - type RepE UVar = EitherE8 (Name (AtomNameC CoreIR)) (Name TyConNameC) + type RepE UVar = EitherE6 (Name (AtomNameC CoreIR)) (Name TyConNameC) (Name DataConNameC) (Name ClassNameC) - (Name MethodNameC) (Name EffectNameC) - (Name EffectOpNameC) (Name TyConNameC) + (Name MethodNameC) (Name TyConNameC) fromE name = case name of UAtomVar v -> Case0 v UTyConVar v -> Case1 v UDataConVar v -> Case2 v UClassVar v -> Case3 v UMethodVar v -> Case4 v - UEffectVar v -> Case5 v - UEffectOpVar v -> Case6 v - UPunVar v -> Case7 v + UPunVar v -> Case5 v {-# INLINE fromE #-} toE name = case name of @@ -648,9 +823,8 @@ instance GenericE UVar where Case2 v -> UDataConVar v Case3 v -> UClassVar v Case4 v -> UMethodVar v - Case5 v -> UEffectVar v - Case6 v -> UEffectOpVar v - Case7 v -> UPunVar v + Case5 v -> UPunVar v + _ -> error "impossible" {-# INLINE toE #-} instance Pretty (UVar n) where @@ -660,8 +834,6 @@ instance Pretty (UVar n) where UDataConVar v -> "Data constructor name: " <> pretty v UClassVar v -> "Class name: " <> pretty v UMethodVar v -> "Method name: " <> pretty v - UEffectVar v -> "Effect name: " <> pretty v - UEffectOpVar v -> "Effect operation name: " <> pretty v UPunVar v -> "Shared type constructor / data constructor name: " <> pretty v -- TODO: name subst instances for the rest of UExpr @@ -683,42 +855,58 @@ instance HasNameHint ModuleSourceName where getNameHint Prelude = getNameHint @String "prelude" getNameHint Main = getNameHint @String "main" -instance HasNameHint (UBinder c n l) where +instance HasNameHint (UBinder' c n l) where getNameHint b = case b of - UBindSource _ v -> getNameHint v - UIgnore -> noHint - UBind _ v _ -> getNameHint v + UBindSource v -> getNameHint v + UIgnore -> noHint + UBind v _ -> getNameHint v -instance Color c => BindsNames (UBinder c) where - toScopeFrag (UBindSource _ _) = emptyOutFrag +instance Color c => BindsNames (UBinder' c) where + toScopeFrag (UBindSource _) = emptyOutFrag toScopeFrag (UIgnore) = emptyOutFrag - toScopeFrag (UBind _ _ b) = toScopeFrag b + toScopeFrag (UBind _ b) = toScopeFrag b -instance Color c => ProvesExt (UBinder c) where -instance Color c => BindsAtMostOneName (UBinder c) c where +instance Color c => ProvesExt (UBinder' c) where +instance Color c => BindsAtMostOneName (UBinder' c) c where b @> x = case b of - UBindSource _ _ -> emptyInFrag - UIgnore -> emptyInFrag - UBind _ _ b' -> b' @> x + UBindSource _ -> emptyInFrag + UIgnore -> emptyInFrag + UBind _ b' -> b' @> x -instance Color c => SinkableB (UBinder c) where +instance Color c => SinkableB (UBinder' c) where sinkingProofB _ _ _ = todoSinkableProof -instance Color c => RenameB (UBinder c) where +instance Color c => RenameB (UBinder' c) where renameB env ub cont = case ub of - UBindSource pos sn -> cont env $ UBindSource pos sn + UBindSource sn -> cont env $ UBindSource sn UIgnore -> cont env UIgnore - UBind ctx sn b -> renameB env b \env' b' -> cont env' $ UBind ctx sn b' + UBind sn b -> renameB env b \env' b' -> cont env' $ UBind sn b' + +instance SinkableB b => SinkableB (WithSrcB b) where + sinkingProofB _ _ _ = todoSinkableProof + +instance RenameB b => RenameB (WithSrcB b) where + renameB env (WithSrcB sid b) cont = + renameB env b \env' b' -> cont env' (WithSrcB sid b') -instance ProvesExt (UAnnBinder req) where -instance BindsNames (UAnnBinder req) where - toScopeFrag (UAnnBinder b _ _) = toScopeFrag b +instance ProvesExt b => ProvesExt (WithSrcB b) where + toExtEvidence (WithSrcB _ b) = toExtEvidence b -instance BindsAtMostOneName (UAnnBinder req) (AtomNameC CoreIR) where - UAnnBinder b _ _ @> x = b @> x +instance BindsNames b => BindsNames (WithSrcB b) where + toScopeFrag (WithSrcB _ b) = toScopeFrag b + +instance BindsAtMostOneName b r => BindsAtMostOneName (WithSrcB b) r where + WithSrcB _ b @> x = b @> x + +instance ProvesExt UAnnBinder where +instance BindsNames UAnnBinder where + toScopeFrag (UAnnBinder _ b _ _) = toScopeFrag b + +instance BindsAtMostOneName UAnnBinder (AtomNameC CoreIR) where + UAnnBinder _ b _ _ @> x = b @> x instance GenericE (WithSrcE e) where - type RepE (WithSrcE e) = PairE (LiftE SrcPosCtx) e + type RepE (WithSrcE e) = PairE (LiftE SrcId) e fromE (WithSrcE ctx x) = PairE (LiftE ctx) x toE (PairE (LiftE ctx) x) = WithSrcE ctx x @@ -743,7 +931,6 @@ instance Ord SourceBlock where compare x y = compare (sbText x) (sbText y) instance Store SymbolicZeros -instance Store LogLevel instance Store PassName instance Store ModuleSourceName instance Store (UVar n) @@ -752,37 +939,7 @@ instance Store (SourceMap n) instance Hashable ModuleSourceName -instance Store SourceName' -instance Hashable SourceName' - -instance IsString SourceName' where - fromString = SourceName' emptySrcPosCtx - -instance IsString (SourceNameOr a VoidS) where - fromString = SourceName emptySrcPosCtx - -instance IsString (SourceOrInternalName c VoidS) where - fromString = SISourceName - -instance IsString (UBinder s VoidS VoidS) where - fromString = UBindSource emptySrcPosCtx - -instance IsString (UPat' VoidS VoidS) where - fromString = UPatBinder . fromString - -instance IsString (UOptAnnBinder VoidS VoidS) where - fromString s = UAnnBinder (fromString s) UNoAnn [] - -instance IsString (UExpr' VoidS) where - fromString = UVar . fromString - -instance IsString (a n) => IsString (WithSrcE a n) where - fromString = WithSrcE emptySrcPosCtx . fromString - -instance IsString (b n l) => IsString (WithSrcB b n l) where - fromString = WithSrcB emptySrcPosCtx . fromString - -deriving instance Show (UBinder s n l) +deriving instance Show (UBinder' s n l) deriving instance Show (UDataDefTrail n) deriving instance Show (ULamExpr n) deriving instance Show (UPiExpr n) @@ -802,3 +959,265 @@ deriving instance Ord (UEffect n) deriving instance Show (UEffectRow n) deriving instance Eq (UEffectRow n) deriving instance Ord (UEffectRow n) + +instance ToJSON LexemeType + +-- === Pretty instances === + +instance Pretty CSBlock where + pretty (IndentedBlock _ decls) = nest 2 $ prettyLines decls + pretty (ExprBlock g) = pArg g + +instance Pretty Group where pretty = prettyFromPrettyPrec +instance PrettyPrec Group where + prettyPrec = undefined + -- prettyPrec (CIdentifier n) = atPrec ArgPrec $ fromString n + -- prettyPrec (CPrim prim args) = prettyOpDefault prim args + -- prettyPrec (CParens blk) = + -- atPrec ArgPrec $ "(" <> p blk <> ")" + -- prettyPrec (CBrackets g) = atPrec ArgPrec $ pretty g + -- prettyPrec (CBin op lhs rhs) = + -- atPrec LowestPrec $ pArg lhs <+> p op <+> pArg rhs + -- prettyPrec (CLambda args body) = + -- atPrec LowestPrec $ "\\" <> spaced args <> "." <> p body + -- prettyPrec (CCase scrut alts) = + -- atPrec LowestPrec $ "case " <> p scrut <> " of " <> prettyLines alts + -- prettyPrec g = atPrec ArgPrec $ fromString $ show g + +instance Pretty Bin where + pretty = \case + EvalBinOp name -> pretty name + DepAmpersand -> "&>" + Dot -> "." + DepComma -> ",>" + Colon -> ":" + DoubleColon -> "::" + Dollar -> "$" + ImplicitArrow -> "->>" + FatArrow -> "->>" + Pipe -> "|" + CSEqual -> "=" + +instance Pretty SourceBlock' where + pretty (TopDecl decl) = pretty decl + pretty d = fromString $ show d + +instance Pretty CTopDecl where + pretty (CSDecl ann decl) = annDoc <> pretty decl + where annDoc = case ann of + PlainLet -> mempty + _ -> pretty ann <> " " + pretty d = fromString $ show d + +instance Pretty CSDecl where + pretty = undefined + -- pretty (CLet pat blk) = pArg pat <+> "=" <+> p blk + -- pretty (CBind pat blk) = pArg pat <+> "<-" <+> p blk + -- pretty (CDefDecl (CDef name args maybeAnn blk)) = + -- "def " <> fromString name <> " " <> prettyParamGroups args <+> annDoc + -- <> nest 2 (hardline <> p blk) + -- where annDoc = case maybeAnn of Just (expl, ty) -> p expl <+> pArg ty + -- Nothing -> mempty + -- pretty (CInstance header givens methods name) = + -- name' <> p header <> p givens <> nest 2 (hardline <> p methods) where + -- name' = case name of + -- Nothing -> "instance " + -- (Just n) -> "named-instance " <> p n <> " " + -- pretty (CExpr e) = p e + +instance Pretty PrimName where + pretty primName = pretty $ "%" ++ showPrimName primName + +instance Pretty (UDataDefTrail n) where + pretty (UDataDefTrail bs) = pretty $ unsafeFromNest bs + +instance Pretty (UAnnBinder n l) where + pretty (UAnnBinder _ b ty _) = pretty b <> ":" <> pretty ty + +instance Pretty (UAnn n) where + pretty (UAnn ty) = ":" <> pretty ty + pretty UNoAnn = mempty + +instance Pretty (UMethodDef' n) where + pretty (UMethodDef b rhs) = pretty b <+> "=" <+> pretty rhs + +instance Pretty (UPat' n l) where pretty = prettyFromPrettyPrec +instance PrettyPrec (UPat' n l) where + prettyPrec pat = case pat of + UPatBinder x -> atPrec ArgPrec $ p x + UPatProd xs -> atPrec ArgPrec $ parens $ commaSep (unsafeFromNest xs) + UPatDepPair (PairB x y) -> atPrec ArgPrec $ parens $ p x <> ",> " <> p y + UPatCon con pats -> atPrec AppPrec $ parens $ p con <+> spaced (unsafeFromNest pats) + UPatTable pats -> atPrec ArgPrec $ p pats + where + p :: Pretty a => a -> Doc ann + p = pretty + +instance Pretty (UAlt n) where + pretty (UAlt pat body) = pretty pat <+> "->" <+> pretty body + +instance Pretty (UTopDecl n l) where + pretty = \case + UDataDefDecl (UDataDef nm bs dataCons) bTyCon bDataCons -> + "data" <+> p bTyCon <+> p nm <+> spaced (unsafeFromNest bs) <+> "where" <> nest 2 + (prettyLines (zip (toList $ unsafeFromNest bDataCons) dataCons)) + UStructDecl bTyCon (UStructDef nm bs fields defs) -> + "struct" <+> p bTyCon <+> p nm <+> spaced (unsafeFromNest bs) <+> "where" <> nest 2 + (prettyLines fields <> prettyLines defs) + UInterface params methodTys interfaceName methodNames -> + "interface" <+> p params <+> p interfaceName + <> hardline <> foldMap (<>hardline) methods + where + methods = [ p b <> ":" <> p (unsafeCoerceE ty) + | (b, ty) <- zip (toList $ unsafeFromNest methodNames) methodTys] + UInstance className bs params methods (RightB UnitB) _ -> + "instance" <+> p bs <+> p className <+> spaced params <+> + prettyLines methods + UInstance className bs params methods (LeftB v) _ -> + "named-instance" <+> p v <+> ":" <+> p bs <+> p className <+> p params + <> prettyLines methods + ULocalDecl decl -> p decl + where + p :: Pretty a => a -> Doc ann + p = pretty + +instance Pretty (UDecl' n l) where + pretty = \case + ULet ann b _ rhs -> align $ pretty ann <+> pretty b <+> "=" <> (nest 2 $ group $ line <> pLowest rhs) + UExprDecl expr -> pretty expr + UPass -> "pass" + +instance Pretty (UEffectRow n) where + pretty (UEffectRow x Nothing) = encloseSep "<" ">" "," $ (pretty <$> toList x) + pretty (UEffectRow x (Just y)) = "{" <> (hsep $ punctuate "," (pretty <$> toList x)) <+> "|" <+> pretty y <> "}" + +instance Pretty e => Pretty (WithSrcs e) where pretty (WithSrcs _ _ x) = pretty x +instance PrettyPrec e => PrettyPrec (WithSrcs e) where prettyPrec (WithSrcs _ _ x) = prettyPrec x + +instance Pretty e => Pretty (WithSrc e) where pretty (WithSrc _ x) = pretty x +instance PrettyPrec e => PrettyPrec (WithSrc e) where prettyPrec (WithSrc _ x) = prettyPrec x + +instance PrettyE e => Pretty (WithSrcE e n) where pretty (WithSrcE _ x) = pretty x +instance PrettyPrecE e => PrettyPrec (WithSrcE e n) where prettyPrec (WithSrcE _ x) = prettyPrec x + +instance PrettyB b => Pretty (WithSrcB b n l) where pretty (WithSrcB _ x) = pretty x +instance PrettyPrecB b => PrettyPrec (WithSrcB b n l) where prettyPrec (WithSrcB _ x) = prettyPrec x + +instance PrettyE e => Pretty (SourceNameOr e n) where + pretty (SourceName _ v) = pretty v + pretty (InternalName _ v _) = pretty v + +instance Pretty (SourceOrInternalName c n) where + pretty (SourceOrInternalName sn) = pretty sn + +instance Pretty (ULamExpr n) where pretty = prettyFromPrettyPrec +instance PrettyPrec (ULamExpr n) where + prettyPrec (ULamExpr bs _ _ _ body) = atPrec LowestPrec $ + "\\" <> pretty bs <+> "." <+> indented (pretty body) + +instance Pretty (UPiExpr n) where pretty = prettyFromPrettyPrec +instance PrettyPrec (UPiExpr n) where + prettyPrec (UPiExpr pats appExpl UPure ty) = atPrec LowestPrec $ align $ + pretty pats <+> pretty appExpl <+> pLowest ty + prettyPrec (UPiExpr pats appExpl eff ty) = atPrec LowestPrec $ align $ + pretty pats <+> pretty appExpl <+> pretty eff <+> pLowest ty + +instance Pretty (UTabPiExpr n) where pretty = prettyFromPrettyPrec +instance PrettyPrec (UTabPiExpr n) where + prettyPrec (UTabPiExpr pat ty) = atPrec LowestPrec $ align $ + pretty pat <+> "=>" <+> pLowest ty + +instance Pretty (UDepPairType n) where pretty = prettyFromPrettyPrec +instance PrettyPrec (UDepPairType n) where + -- TODO: print explicitness info + prettyPrec (UDepPairType _ pat ty) = atPrec LowestPrec $ align $ + pretty pat <+> "&>" <+> pLowest ty + +instance Pretty (UBlock' n) where + pretty (UBlock decls result) = + prettyLines (unsafeFromNest decls) <> hardline <> pLowest result + +instance Pretty (UExpr' n) where pretty = prettyFromPrettyPrec +instance PrettyPrec (UExpr' n) where + prettyPrec expr = case expr of + ULit l -> prettyPrec l + UVar v -> atPrec ArgPrec $ p v + ULam lam -> prettyPrec lam + UApp f xs named -> atPrec AppPrec $ pAppArg (pApp f) xs <+> p named + UTabApp f x -> atPrec AppPrec $ pArg f <> "." <> pArg x + UFor dir (UForExpr binder body) -> + atPrec LowestPrec $ kw <+> p binder <> "." + <+> nest 2 (p body) + where kw = case dir of Fwd -> "for" + Rev -> "rof" + UPi piType -> prettyPrec piType + UTabPi piType -> prettyPrec piType + UDepPairTy depPairType -> prettyPrec depPairType + UDepPair lhs rhs -> atPrec ArgPrec $ parens $ + p lhs <+> ",>" <+> p rhs + UHole -> atPrec ArgPrec "_" + UTypeAnn v ty -> atPrec LowestPrec $ + group $ pApp v <> line <> ":" <+> pApp ty + UTabCon xs -> atPrec ArgPrec $ p xs + UPrim prim xs -> atPrec AppPrec $ p (show prim) <+> p xs + UCase e alts -> atPrec LowestPrec $ "case" <+> p e <> + nest 2 (prettyLines alts) + UFieldAccess x (WithSrc _ f) -> atPrec AppPrec $ p x <> "~" <> p f + UNatLit v -> atPrec ArgPrec $ p v + UIntLit v -> atPrec ArgPrec $ p v + UFloatLit v -> atPrec ArgPrec $ p v + UDo block -> atPrec LowestPrec $ p block + where + p :: Pretty a => a -> Doc ann + p = pretty + +instance Pretty SourceBlock where + pretty block = pretty $ ensureNewline (sbText block) where + -- Force the SourceBlock to end in a newline for echoing, even if + -- it was terminated with EOF in the original program. + ensureNewline t = case unsnoc t of + Nothing -> t + Just (_, '\n') -> t + _ -> t `snoc` '\n' + +instance Pretty Output where + pretty = \case + TextOut s -> pretty s + HtmlOut _ -> "" + SourceInfo _ -> "" + PassInfo _ s -> pretty s + MiscLog s -> pretty s + Error e -> pretty e + +instance Pretty PassName where + pretty x = pretty $ show x + +instance Pretty Result where + pretty (Result (Outputs outs) r) = vcat (map pretty outs) <> maybeErr + where maybeErr = case r of Failure err -> pretty err + Success () -> mempty + +instance Pretty (UBinder' c n l) where pretty = prettyFromPrettyPrec +instance PrettyPrec (UBinder' c n l) where + prettyPrec b = atPrec ArgPrec case b of + UBindSource v -> pretty v + UIgnore -> "_" + UBind v _ -> pretty v + +instance Pretty FieldName' where + pretty = \case + FieldName s -> pretty s + FieldNum n -> pretty n + +instance Pretty (UEffect n) where + pretty eff = case eff of + URWSEffect rws h -> pretty rws <+> pretty h + UExceptionEffect -> "Except" + UIOEffect -> "IO" + +prettyOpDefault :: PrettyPrec a => PrimName -> [a] -> DocPrec ann +prettyOpDefault name args = + case length args of + 0 -> atPrec ArgPrec primName + _ -> atPrec AppPrec $ pAppArg primName args + where primName = pretty name diff --git a/src/lib/Types/Top.hs b/src/lib/Types/Top.hs new file mode 100644 index 000000000..b67fe357f --- /dev/null +++ b/src/lib/Types/Top.hs @@ -0,0 +1,1035 @@ +-- Copyright 2022 Google LLC +-- +-- Use of this source code is governed by a BSD-style +-- license that can be found in the LICENSE file or at +-- https://developers.google.com/open-source/licenses/bsd + +{-# LANGUAGE UndecidableInstances #-} +{-# LANGUAGE StrictData #-} + +-- Top-level data types + +module Types.Top where + +import Data.Functor ((<&>)) +import Data.Hashable +import Data.Text.Prettyprint.Doc +import qualified Data.Map.Strict as M +import qualified Data.Set as S + +import GHC.Generics (Generic (..)) +import Data.Store (Store (..)) +import Foreign.Ptr + +import Name +import Util (FileHash, SnocList (..)) +import IRVariants +import PPrint + +import Types.Primitives +import Types.Core +import Types.Source +import Types.Imp + +type TopBlock = TopLam -- used for nullary lambda +type IsDestLam = Bool +data TopLam (r::IR) (n::S) = TopLam IsDestLam (PiType r n) (LamExpr r n) + deriving (Show, Generic) +type STopLam = TopLam SimpIR +type CTopLam = TopLam CoreIR + +data EvalStatus a = Waiting | Running | Finished a + deriving (Show, Eq, Ord, Generic, Functor, Foldable, Traversable) +type TopFunEvalStatus n = EvalStatus (TopFunLowerings n) + +data TopFun (n::S) = + DexTopFun (TopFunDef n) (TopLam SimpIR n) (TopFunEvalStatus n) + | FFITopFun String IFunType + deriving (Show, Generic) + +data TopFunDef (n::S) = + Specialization (SpecializationSpec n) + | LinearizationPrimal (LinearizationSpec n) + -- Tangent functions all take some number of nonlinear args, then a *single* + -- linear arg. This is so that transposition can be an involution - you apply + -- it twice and you get back to the original function. + | LinearizationTangent (LinearizationSpec n) + deriving (Show, Generic) + +newtype TopFunLowerings (n::S) = TopFunLowerings + { topFunObjCode :: FunObjCodeName n } -- TODO: add optimized, imp etc. as needed + deriving (Show, Generic, SinkableE, HoistableE, RenameE, AlphaEqE, AlphaHashableE, Pretty) + +data AtomBinding (r::IR) (n::S) where + LetBound :: DeclBinding r n -> AtomBinding r n + MiscBound :: Type r n -> AtomBinding r n + TopDataBound :: RepVal n -> AtomBinding SimpIR n + SolverBound :: SolverBinding n -> AtomBinding CoreIR n + NoinlineFun :: CType n -> CAtom n -> AtomBinding CoreIR n + FFIFunBound :: CorePiType n -> TopFunName n -> AtomBinding CoreIR n + +deriving instance IRRep r => Show (AtomBinding r n) +deriving via WrapE (AtomBinding r) n instance IRRep r => Generic (AtomBinding r n) + +data SolverBinding (n::S) = + InfVarBound (CType n) + | SkolemBound (CType n) + | DictBound (CType n) + deriving (Show, Generic) + +-- TODO: Use an IntMap +newtype CustomRules (n::S) = + CustomRules { customRulesMap :: M.Map (AtomName CoreIR n) (AtomRules n) } + deriving (Semigroup, Monoid, Store) +data AtomRules (n::S) = + -- number of implicit args, number of explicit args, linearization function + CustomLinearize Int Int SymbolicZeros (CAtom n) + deriving (Generic) + +-- === envs and modules === + +-- `ModuleEnv` contains data that only makes sense in the context of evaluating +-- a particular module. `TopEnv` contains everything that makes sense "between" +-- evaluating modules. +data Env n = Env + { topEnv :: {-# UNPACK #-} TopEnv n + , moduleEnv :: {-# UNPACK #-} ModuleEnv n } + deriving (Generic) + +newtype EnvFrag (n::S) (l::S) = EnvFrag (RecSubstFrag Binding n l) + deriving (OutFrag) + +data TopEnv (n::S) = TopEnv + { envDefs :: RecSubst Binding n + , envCustomRules :: CustomRules n + , envCache :: Cache n + , envLoadedModules :: LoadedModules n + , envLoadedObjects :: LoadedObjects n } + deriving (Generic) + +data SerializedEnv n = SerializedEnv + { serializedEnvDefs :: RecSubst Binding n + , serializedEnvCustomRules :: CustomRules n + , serializedEnvCache :: Cache n } + deriving (Generic) + +-- TODO: consider splitting this further into `ModuleEnv` (the env that's +-- relevant between top-level decls) and `LocalEnv` (the additional parts of the +-- env that's relevant under a lambda binder). Unlike the Top/Module +-- distinction, there's some overlap. For example, instances can be defined at +-- both the module-level and local level. Similarly, if we start allowing +-- top-level effects in `Main` then we'll have module-level effects and local +-- effects. +data ModuleEnv (n::S) = ModuleEnv + { envImportStatus :: ImportStatus n + , envSourceMap :: SourceMap n + , envSynthCandidates :: SynthCandidates n } + deriving (Generic) + +data Module (n::S) = Module + { moduleSourceName :: ModuleSourceName + , moduleDirectDeps :: S.Set (ModuleName n) + , moduleTransDeps :: S.Set (ModuleName n) -- XXX: doesn't include the module itself + , moduleExports :: SourceMap n + -- these are just the synth candidates required by this + -- module by itself. We'll usually also need those required by the module's + -- (transitive) dependencies, which must be looked up separately. + , moduleSynthCandidates :: SynthCandidates n } + deriving (Show, Generic) + +data LoadedModules (n::S) = LoadedModules + { fromLoadedModules :: M.Map ModuleSourceName (ModuleName n)} + deriving (Show, Generic) + +emptyModuleEnv :: ModuleEnv n +emptyModuleEnv = ModuleEnv emptyImportStatus (SourceMap mempty) mempty + +emptyLoadedModules :: LoadedModules n +emptyLoadedModules = LoadedModules mempty + +data LoadedObjects (n::S) = LoadedObjects + -- the pointer points to the actual runtime function + { fromLoadedObjects :: M.Map (FunObjCodeName n) NativeFunction} + deriving (Show, Generic) + +emptyLoadedObjects :: LoadedObjects n +emptyLoadedObjects = LoadedObjects mempty + +data ImportStatus (n::S) = ImportStatus + { directImports :: S.Set (ModuleName n) + -- XXX: This are cached for efficiency. It's derivable from `directImports`. + , transImports :: S.Set (ModuleName n) } + deriving (Show, Generic) + +data TopEnvFrag n l = TopEnvFrag (EnvFrag n l) (ModuleEnv l) (SnocList (TopEnvUpdate l)) + +data TopEnvUpdate n = + ExtendCache (Cache n) + | AddCustomRule (CAtomName n) (AtomRules n) + | UpdateLoadedModules ModuleSourceName (ModuleName n) + | UpdateLoadedObjects (FunObjCodeName n) NativeFunction + | FinishDictSpecialization (SpecDictName n) [TopLam SimpIR n] + | LowerDictSpecialization (SpecDictName n) [TopLam SimpIR n] + | UpdateTopFunEvalStatus (TopFunName n) (TopFunEvalStatus n) + | UpdateInstanceDef (InstanceName n) (InstanceDef n) + | UpdateTyConDef (TyConName n) (TyConDef n) + | UpdateFieldDef (TyConName n) SourceName (CAtomName n) + +-- TODO: we could add a lot more structure for querying by dict type, caching, etc. +data SynthCandidates n = SynthCandidates + { instanceDicts :: M.Map (ClassName n) [InstanceName n] + , ixInstances :: [InstanceName n] } + deriving (Show, Generic) + +emptyImportStatus :: ImportStatus n +emptyImportStatus = ImportStatus mempty mempty + +-- TODO: figure out the additional top-level context we need -- backend, other +-- compiler flags etc. We can have a map from those to this. + +data Cache (n::S) = Cache + { specializationCache :: EMap SpecializationSpec TopFunName n + , ixDictCache :: EMap AbsDict SpecDictName n + , linearizationCache :: EMap LinearizationSpec (PairE TopFunName TopFunName) n + , transpositionCache :: EMap TopFunName TopFunName n + -- This is memoizing `parseAndGetDeps :: Text -> [ModuleSourceName]`. But we + -- only want to store one entry per module name as a simple cache eviction + -- policy, so we store it keyed on the module name, with the text hash for + -- the validity check. + , parsedDeps :: M.Map ModuleSourceName (FileHash, [ModuleSourceName]) + , moduleEvaluations :: M.Map ModuleSourceName ((FileHash, [ModuleName n]), ModuleName n) + } deriving (Show, Generic) + +-- === runtime function and variable representations === + +type RuntimeEnv = DynamicVarKeyPtrs + +type DexDestructor = FunPtr (IO ()) + +data NativeFunction = NativeFunction + { nativeFunPtr :: FunPtr () + , nativeFunTeardown :: IO () } + +instance Show NativeFunction where + show _ = "" + +-- Holds pointers to thread-local storage used to simulate dynamically scoped +-- variables, such as the output stream file descriptor. +type DynamicVarKeyPtrs = [(DynamicVar, Ptr ())] + +data DynamicVar = OutStreamDyvar -- TODO: add others as needed + deriving (Enum, Bounded) + +dynamicVarCName :: DynamicVar -> String +dynamicVarCName OutStreamDyvar = "dex_out_stream_dyvar" + +dynamicVarLinkMap :: DynamicVarKeyPtrs -> [(String, Ptr ())] +dynamicVarLinkMap dyvars = dyvars <&> \(v, ptr) -> (dynamicVarCName v, ptr) + +-- === Specialization and generalization === + +type Generalized (r::IR) (e::E) (n::S) = (Abstracted r e n, [Atom r n]) +type Abstracted (r::IR) (e::E) = Abs (Nest (Binder r)) e +type AbsDict = Abstracted CoreIR (Dict CoreIR) + +data SpecializedDictDef n = + SpecializedDict + (AbsDict n) + -- Methods (thunked if nullary), if they're available. + -- We create specialized dict names during simplification, but we don't + -- actually simplify/lower them until we return to TopLevel + (Maybe [TopLam SimpIR n]) + deriving (Show, Generic) + +-- TODO: extend with AD-oriented specializations, backend-specific specializations etc. +data SpecializationSpec (n::S) = + AppSpecialization (AtomVar CoreIR n) (Abstracted CoreIR (ListE CAtom) n) + deriving (Show, Generic) + +type Active = Bool +data LinearizationSpec (n::S) = LinearizationSpec (TopFunName n) [Active] + deriving (Show, Generic) + +-- === bindings - static information we carry about a lexical scope === + +-- TODO: consider making this an open union via a typeable-like class +data Binding (c::C) (n::S) where + AtomNameBinding :: AtomBinding r n -> Binding (AtomNameC r) n + TyConBinding :: Maybe (TyConDef n) -> DotMethods n -> Binding TyConNameC n + DataConBinding :: TyConName n -> Int -> Binding DataConNameC n + ClassBinding :: ClassDef n -> Binding ClassNameC n + InstanceBinding :: InstanceDef n -> CorePiType n -> Binding InstanceNameC n + MethodBinding :: ClassName n -> Int -> Binding MethodNameC n + TopFunBinding :: TopFun n -> Binding TopFunNameC n + FunObjCodeBinding :: CFunction n -> Binding FunObjCodeNameC n + ModuleBinding :: Module n -> Binding ModuleNameC n + -- TODO: add a case for abstracted pointers, as used in `ClosedImpFunction` + PtrBinding :: PtrType -> PtrLitVal -> Binding PtrNameC n + SpecializedDictBinding :: SpecializedDictDef n -> Binding SpecializedDictNameC n + ImpNameBinding :: BaseType -> Binding ImpNameC n + +-- === ToBinding === + +atomBindingToBinding :: AtomBinding r n -> Binding (AtomNameC r) n +atomBindingToBinding b = AtomNameBinding b + +bindingToAtomBinding :: Binding (AtomNameC r) n -> AtomBinding r n +bindingToAtomBinding (AtomNameBinding b) = b + +class (RenameE e, SinkableE e) => ToBinding (e::E) (c::C) | e -> c where + toBinding :: e n -> Binding c n + +instance Color c => ToBinding (Binding c) c where + toBinding = id + +instance IRRep r => ToBinding (AtomBinding r) (AtomNameC r) where + toBinding = atomBindingToBinding + +instance IRRep r => ToBinding (DeclBinding r) (AtomNameC r) where + toBinding = toBinding . LetBound + +instance IRRep r => ToBinding (Type r) (AtomNameC r) where + toBinding = toBinding . MiscBound + +instance ToBinding SolverBinding (AtomNameC CoreIR) where + toBinding = toBinding . SolverBound + +instance IRRep r => ToBinding (IxType r) (AtomNameC r) where + toBinding (IxType t _) = toBinding t + +instance (ToBinding e1 c, ToBinding e2 c) => ToBinding (EitherE e1 e2) c where + toBinding (LeftE e) = toBinding e + toBinding (RightE e) = toBinding e + +instance ToBindersAbs (TopLam r) (Expr r) r where + toAbs (TopLam _ _ lam) = toAbs lam + +-- === GenericE, GenericB === + +instance GenericE SpecializedDictDef where + type RepE SpecializedDictDef = AbsDict `PairE` MaybeE (ListE (TopLam SimpIR)) + fromE (SpecializedDict ab methods) = ab `PairE` methods' + where methods' = case methods of Just xs -> LeftE (ListE xs) + Nothing -> RightE UnitE + {-# INLINE fromE #-} + toE (ab `PairE` methods) = SpecializedDict ab methods' + where methods' = case methods of LeftE (ListE xs) -> Just xs + RightE UnitE -> Nothing + {-# INLINE toE #-} + +instance SinkableE SpecializedDictDef +instance HoistableE SpecializedDictDef +instance AlphaEqE SpecializedDictDef +instance AlphaHashableE SpecializedDictDef +instance RenameE SpecializedDictDef + +instance HasScope Env where + toScope = toScope . envDefs . topEnv + +instance OutMap Env where + emptyOutMap = + Env (TopEnv (RecSubst emptyInFrag) mempty mempty emptyLoadedModules emptyLoadedObjects) + emptyModuleEnv + {-# INLINE emptyOutMap #-} + +instance ExtOutMap Env (RecSubstFrag Binding) where + -- TODO: We might want to reorganize this struct to make this + -- do less explicit sinking etc. It's a hot operation! + extendOutMap (Env (TopEnv defs rules cache loadedM loadedO) moduleEnv) frag = + withExtEvidence frag $ Env + (TopEnv + (defs `extendRecSubst` frag) + (sink rules) + (sink cache) + (sink loadedM) + (sink loadedO)) + (sink moduleEnv) + {-# INLINE extendOutMap #-} + +instance ExtOutMap Env EnvFrag where + extendOutMap = extendEnv + {-# INLINE extendOutMap #-} + +extendEnv :: Distinct l => Env n -> EnvFrag n l -> Env l +extendEnv env (EnvFrag newEnv) = do + case extendOutMap env newEnv of + Env envTop (ModuleEnv imports sm scs) -> do + Env envTop (ModuleEnv imports sm scs) +{-# NOINLINE [1] extendEnv #-} + + +instance GenericE AtomRules where + type RepE AtomRules = (LiftE (Int, Int, SymbolicZeros)) `PairE` CAtom + fromE (CustomLinearize ni ne sz a) = LiftE (ni, ne, sz) `PairE` a + toE (LiftE (ni, ne, sz) `PairE` a) = CustomLinearize ni ne sz a +instance SinkableE AtomRules +instance HoistableE AtomRules +instance AlphaEqE AtomRules +instance RenameE AtomRules + +instance GenericE CustomRules where + type RepE CustomRules = ListE (PairE (AtomName CoreIR) AtomRules) + fromE (CustomRules m) = ListE $ toPairE <$> M.toList m + toE (ListE l) = CustomRules $ M.fromList $ fromPairE <$> l +instance SinkableE CustomRules +instance HoistableE CustomRules +instance AlphaEqE CustomRules +instance RenameE CustomRules + +instance GenericE Cache where + type RepE Cache = + EMap SpecializationSpec TopFunName + `PairE` EMap AbsDict SpecDictName + `PairE` EMap LinearizationSpec (PairE TopFunName TopFunName) + `PairE` EMap TopFunName TopFunName + `PairE` LiftE (M.Map ModuleSourceName (FileHash, [ModuleSourceName])) + `PairE` ListE ( LiftE ModuleSourceName + `PairE` LiftE FileHash + `PairE` ListE ModuleName + `PairE` ModuleName) + fromE (Cache x y z w parseCache evalCache) = + x `PairE` y `PairE` z `PairE` w `PairE` LiftE parseCache `PairE` + ListE [LiftE sourceName `PairE` LiftE hashVal `PairE` ListE deps `PairE` result + | (sourceName, ((hashVal, deps), result)) <- M.toList evalCache ] + {-# INLINE fromE #-} + toE (x `PairE` y `PairE` z `PairE` w `PairE` LiftE parseCache `PairE` ListE evalCache) = + Cache x y z w parseCache + (M.fromList + [(sourceName, ((hashVal, deps), result)) + | LiftE sourceName `PairE` LiftE hashVal `PairE` ListE deps `PairE` result + <- evalCache]) + {-# INLINE toE #-} + +instance SinkableE Cache +instance HoistableE Cache +instance AlphaEqE Cache +instance RenameE Cache +instance Store (Cache n) + +instance Monoid (Cache n) where + mempty = Cache mempty mempty mempty mempty mempty mempty + mappend = (<>) + +instance Semigroup (Cache n) where + -- right-biased instead of left-biased + Cache x1 x2 x3 x4 x5 x6 <> Cache y1 y2 y3 y4 y5 y6 = + Cache (y1<>x1) (y2<>x2) (y3<>x3) (y4<>x4) (x5<>y5) (x6<>y6) + + +instance GenericE SynthCandidates where + type RepE SynthCandidates = ListE (PairE ClassName (ListE InstanceName)) + `PairE` ListE InstanceName + fromE (SynthCandidates xs ys) = ListE xs' `PairE` ListE ys + where xs' = map (\(k,vs) -> PairE k (ListE vs)) (M.toList xs) + {-# INLINE fromE #-} + toE (ListE xs `PairE` ListE ys) = SynthCandidates xs' ys + where xs' = M.fromList $ map (\(PairE k (ListE vs)) -> (k,vs)) xs + {-# INLINE toE #-} + +instance SinkableE SynthCandidates +instance HoistableE SynthCandidates +instance AlphaEqE SynthCandidates +instance AlphaHashableE SynthCandidates +instance RenameE SynthCandidates + +instance IRRep r => GenericE (AtomBinding r) where + type RepE (AtomBinding r) = + EitherE2 (EitherE3 + (DeclBinding r) -- LetBound + (Type r) -- MiscBound + (WhenCore r SolverBinding) -- SolverBound + ) (EitherE3 + (WhenCore r (PairE CType CAtom)) -- NoinlineFun + (WhenSimp r RepVal) -- TopDataBound + (WhenCore r (CorePiType `PairE` TopFunName)) -- FFIFunBound + ) + + fromE = \case + LetBound x -> Case0 $ Case0 x + MiscBound x -> Case0 $ Case1 x + SolverBound x -> Case0 $ Case2 $ WhenIRE x + NoinlineFun t x -> Case1 $ Case0 $ WhenIRE $ PairE t x + TopDataBound repVal -> Case1 $ Case1 $ WhenIRE repVal + FFIFunBound ty v -> Case1 $ Case2 $ WhenIRE $ ty `PairE` v + {-# INLINE fromE #-} + + toE = \case + Case0 x' -> case x' of + Case0 x -> LetBound x + Case1 x -> MiscBound x + Case2 (WhenIRE x) -> SolverBound x + _ -> error "impossible" + Case1 x' -> case x' of + Case0 (WhenIRE (PairE t x)) -> NoinlineFun t x + Case1 (WhenIRE repVal) -> TopDataBound repVal + Case2 (WhenIRE (ty `PairE` v)) -> FFIFunBound ty v + _ -> error "impossible" + _ -> error "impossible" + {-# INLINE toE #-} + + +instance IRRep r => SinkableE (AtomBinding r) +instance IRRep r => HoistableE (AtomBinding r) +instance IRRep r => RenameE (AtomBinding r) +instance IRRep r => AlphaEqE (AtomBinding r) +instance IRRep r => AlphaHashableE (AtomBinding r) + +instance GenericE TopFunDef where + type RepE TopFunDef = EitherE3 SpecializationSpec LinearizationSpec LinearizationSpec + fromE = \case + Specialization s -> Case0 s + LinearizationPrimal s -> Case1 s + LinearizationTangent s -> Case2 s + {-# INLINE fromE #-} + toE = \case + Case0 s -> Specialization s + Case1 s -> LinearizationPrimal s + Case2 s -> LinearizationTangent s + _ -> error "impossible" + {-# INLINE toE #-} + +instance SinkableE TopFunDef +instance HoistableE TopFunDef +instance RenameE TopFunDef +instance AlphaEqE TopFunDef +instance AlphaHashableE TopFunDef + +instance IRRep r => GenericE (TopLam r) where + type RepE (TopLam r) = LiftE Bool `PairE` PiType r `PairE` LamExpr r + fromE (TopLam d x y) = LiftE d `PairE` x `PairE` y + {-# INLINE fromE #-} + toE (LiftE d `PairE` x `PairE` y) = TopLam d x y + {-# INLINE toE #-} + +instance IRRep r => SinkableE (TopLam r) +instance IRRep r => HoistableE (TopLam r) +instance IRRep r => RenameE (TopLam r) +instance IRRep r => AlphaEqE (TopLam r) +instance IRRep r => AlphaHashableE (TopLam r) + +instance GenericE TopFun where + type RepE TopFun = EitherE + (TopFunDef `PairE` TopLam SimpIR `PairE` ComposeE EvalStatus TopFunLowerings) + (LiftE (String, IFunType)) + fromE = \case + DexTopFun def lam status -> LeftE (def `PairE` lam `PairE` ComposeE status) + FFITopFun name ty -> RightE (LiftE (name, ty)) + {-# INLINE fromE #-} + toE = \case + LeftE (def `PairE` lam `PairE` ComposeE status) -> DexTopFun def lam status + RightE (LiftE (name, ty)) -> FFITopFun name ty + {-# INLINE toE #-} + +instance SinkableE TopFun +instance HoistableE TopFun +instance RenameE TopFun +instance AlphaEqE TopFun +instance AlphaHashableE TopFun + +instance GenericE SpecializationSpec where + type RepE SpecializationSpec = + PairE (AtomVar CoreIR) (Abs (Nest (Binder CoreIR)) (ListE CAtom)) + fromE (AppSpecialization fname (Abs bs args)) = PairE fname (Abs bs args) + {-# INLINE fromE #-} + toE (PairE fname (Abs bs args)) = AppSpecialization fname (Abs bs args) + {-# INLINE toE #-} + +instance HasNameHint (SpecializationSpec n) where + getNameHint (AppSpecialization f _) = getNameHint f + +instance SinkableE SpecializationSpec +instance HoistableE SpecializationSpec +instance RenameE SpecializationSpec +instance AlphaEqE SpecializationSpec +instance AlphaHashableE SpecializationSpec + +instance GenericE LinearizationSpec where + type RepE LinearizationSpec = PairE TopFunName (LiftE [Active]) + fromE (LinearizationSpec fname actives) = PairE fname (LiftE actives) + {-# INLINE fromE #-} + toE (PairE fname (LiftE actives)) = LinearizationSpec fname actives + {-# INLINE toE #-} + +instance SinkableE LinearizationSpec +instance HoistableE LinearizationSpec +instance RenameE LinearizationSpec +instance AlphaEqE LinearizationSpec +instance AlphaHashableE LinearizationSpec + +instance GenericE SolverBinding where + type RepE SolverBinding = EitherE3 + CType + CType + CType + fromE = \case + InfVarBound ty -> Case0 ty + SkolemBound ty -> Case1 ty + DictBound ty -> Case2 ty + {-# INLINE fromE #-} + + toE = \case + Case0 ty -> InfVarBound ty + Case1 ty -> SkolemBound ty + Case2 ty -> DictBound ty + _ -> error "impossible" + {-# INLINE toE #-} + +instance SinkableE SolverBinding +instance HoistableE SolverBinding +instance RenameE SolverBinding +instance AlphaEqE SolverBinding +instance AlphaHashableE SolverBinding + +instance GenericE (Binding c) where + type RepE (Binding c) = + EitherE3 + (EitherE6 + (WhenAtomName c AtomBinding) + (WhenC TyConNameC c (MaybeE TyConDef `PairE` DotMethods)) + (WhenC DataConNameC c (TyConName `PairE` LiftE Int)) + (WhenC ClassNameC c (ClassDef)) + (WhenC InstanceNameC c (InstanceDef `PairE` CorePiType)) + (WhenC MethodNameC c (ClassName `PairE` LiftE Int))) + (EitherE4 + (WhenC TopFunNameC c (TopFun)) + (WhenC FunObjCodeNameC c (CFunction)) + (WhenC ModuleNameC c (Module)) + (WhenC PtrNameC c (LiftE (PtrType, PtrLitVal)))) + (EitherE2 + (WhenC SpecializedDictNameC c (SpecializedDictDef)) + (WhenC ImpNameC c (LiftE BaseType))) + + fromE = \case + AtomNameBinding binding -> Case0 $ Case0 $ WhenAtomName binding + TyConBinding dataDef methods -> Case0 $ Case1 $ WhenC $ toMaybeE dataDef `PairE` methods + DataConBinding dataDefName idx -> Case0 $ Case2 $ WhenC $ dataDefName `PairE` LiftE idx + ClassBinding classDef -> Case0 $ Case3 $ WhenC $ classDef + InstanceBinding instanceDef ty -> Case0 $ Case4 $ WhenC $ instanceDef `PairE` ty + MethodBinding className idx -> Case0 $ Case5 $ WhenC $ className `PairE` LiftE idx + TopFunBinding fun -> Case1 $ Case0 $ WhenC $ fun + FunObjCodeBinding cFun -> Case1 $ Case1 $ WhenC $ cFun + ModuleBinding m -> Case1 $ Case2 $ WhenC $ m + PtrBinding ty p -> Case1 $ Case3 $ WhenC $ LiftE (ty,p) + SpecializedDictBinding def -> Case2 $ Case0 $ WhenC $ def + ImpNameBinding ty -> Case2 $ Case1 $ WhenC $ LiftE ty + {-# INLINE fromE #-} + + toE = \case + Case0 (Case0 (WhenAtomName binding)) -> AtomNameBinding binding + Case0 (Case1 (WhenC (def `PairE` methods))) -> TyConBinding (fromMaybeE def) methods + Case0 (Case2 (WhenC (n `PairE` LiftE idx))) -> DataConBinding n idx + Case0 (Case3 (WhenC (classDef))) -> ClassBinding classDef + Case0 (Case4 (WhenC (instanceDef `PairE` ty))) -> InstanceBinding instanceDef ty + Case0 (Case5 (WhenC ((n `PairE` LiftE i)))) -> MethodBinding n i + Case1 (Case0 (WhenC (fun))) -> TopFunBinding fun + Case1 (Case1 (WhenC (f))) -> FunObjCodeBinding f + Case1 (Case2 (WhenC (m))) -> ModuleBinding m + Case1 (Case3 (WhenC ((LiftE (ty,p))))) -> PtrBinding ty p + Case2 (Case0 (WhenC (def))) -> SpecializedDictBinding def + Case2 (Case1 (WhenC ((LiftE ty)))) -> ImpNameBinding ty + _ -> error "impossible" + {-# INLINE toE #-} + +deriving via WrapE (Binding c) n instance Generic (Binding c n) +instance SinkableV Binding +instance HoistableV Binding +instance RenameV Binding +instance Color c => SinkableE (Binding c) +instance Color c => HoistableE (Binding c) +instance Color c => RenameE (Binding c) + +instance Semigroup (SynthCandidates n) where + SynthCandidates xs ys <> SynthCandidates xs' ys' = + SynthCandidates (M.unionWith (<>) xs xs') (ys <> ys') + +instance Monoid (SynthCandidates n) where + mempty = SynthCandidates mempty mempty + +instance GenericB EnvFrag where + type RepB EnvFrag = RecSubstFrag Binding + fromB (EnvFrag frag) = frag + toB frag = EnvFrag frag + +instance SinkableB EnvFrag +instance HoistableB EnvFrag +instance ProvesExt EnvFrag +instance BindsNames EnvFrag +instance RenameB EnvFrag + +instance GenericE TopEnvUpdate where + type RepE TopEnvUpdate = EitherE2 ( + EitherE4 + {- ExtendCache -} Cache + {- AddCustomRule -} (CAtomName `PairE` AtomRules) + {- UpdateLoadedModules -} (LiftE ModuleSourceName `PairE` ModuleName) + {- UpdateLoadedObjects -} (FunObjCodeName `PairE` LiftE NativeFunction) + ) ( EitherE6 + {- FinishDictSpecialization -} (SpecDictName `PairE` ListE (TopLam SimpIR)) + {- LowerDictSpecialization -} (SpecDictName `PairE` ListE (TopLam SimpIR)) + {- UpdateTopFunEvalStatus -} (TopFunName `PairE` ComposeE EvalStatus TopFunLowerings) + {- UpdateInstanceDef -} (InstanceName `PairE` InstanceDef) + {- UpdateTyConDef -} (TyConName `PairE` TyConDef) + {- UpdateFieldDef -} (TyConName `PairE` LiftE SourceName `PairE` CAtomName) + ) + fromE = \case + ExtendCache x -> Case0 $ Case0 x + AddCustomRule x y -> Case0 $ Case1 (x `PairE` y) + UpdateLoadedModules x y -> Case0 $ Case2 (LiftE x `PairE` y) + UpdateLoadedObjects x y -> Case0 $ Case3 (x `PairE` LiftE y) + FinishDictSpecialization x y -> Case1 $ Case0 (x `PairE` ListE y) + LowerDictSpecialization x y -> Case1 $ Case1 (x `PairE` ListE y) + UpdateTopFunEvalStatus x y -> Case1 $ Case2 (x `PairE` ComposeE y) + UpdateInstanceDef x y -> Case1 $ Case3 (x `PairE` y) + UpdateTyConDef x y -> Case1 $ Case4 (x `PairE` y) + UpdateFieldDef x y z -> Case1 $ Case5 (x `PairE` LiftE y `PairE` z) + + toE = \case + Case0 e -> case e of + Case0 x -> ExtendCache x + Case1 (x `PairE` y) -> AddCustomRule x y + Case2 (LiftE x `PairE` y) -> UpdateLoadedModules x y + Case3 (x `PairE` LiftE y) -> UpdateLoadedObjects x y + _ -> error "impossible" + Case1 e -> case e of + Case0 (x `PairE` ListE y) -> FinishDictSpecialization x y + Case1 (x `PairE` ListE y) -> LowerDictSpecialization x y + Case2 (x `PairE` ComposeE y) -> UpdateTopFunEvalStatus x y + Case3 (x `PairE` y) -> UpdateInstanceDef x y + Case4 (x `PairE` y) -> UpdateTyConDef x y + Case5 (x `PairE` LiftE y `PairE` z) -> UpdateFieldDef x y z + _ -> error "impossible" + _ -> error "impossible" + +instance SinkableE TopEnvUpdate +instance HoistableE TopEnvUpdate +instance RenameE TopEnvUpdate + +instance GenericB TopEnvFrag where + type RepB TopEnvFrag = PairB EnvFrag (LiftB (ModuleEnv `PairE` ListE TopEnvUpdate)) + fromB (TopEnvFrag x y (ReversedList z)) = PairB x (LiftB (y `PairE` ListE z)) + toB (PairB x (LiftB (y `PairE` ListE z))) = TopEnvFrag x y (ReversedList z) + +instance RenameB TopEnvFrag +instance HoistableB TopEnvFrag +instance SinkableB TopEnvFrag +instance ProvesExt TopEnvFrag +instance BindsNames TopEnvFrag + +instance OutFrag TopEnvFrag where + emptyOutFrag = TopEnvFrag emptyOutFrag mempty mempty + {-# INLINE emptyOutFrag #-} + catOutFrags (TopEnvFrag frag1 env1 partial1) + (TopEnvFrag frag2 env2 partial2) = + withExtEvidence frag2 $ + TopEnvFrag + (catOutFrags frag1 frag2) + (sink env1 <> env2) + (sinkSnocList partial1 <> partial2) + {-# INLINE catOutFrags #-} + +-- XXX: unlike `ExtOutMap Env EnvFrag` instance, this once doesn't +-- extend the synthesis candidates based on the annotated let-bound names. It +-- only extends synth candidates when they're supplied explicitly. +instance ExtOutMap Env TopEnvFrag where + extendOutMap env (TopEnvFrag (EnvFrag frag) mEnv' otherUpdates) = do + let newerTopEnv = foldl applyUpdate newTopEnv otherUpdates + Env newerTopEnv newModuleEnv + where + Env (TopEnv defs rules cache loadedM loadedO) mEnv = env + + newTopEnv = withExtEvidence frag $ TopEnv + (defs `extendRecSubst` frag) + (sink rules) (sink cache) (sink loadedM) (sink loadedO) + + newModuleEnv = + ModuleEnv + (imports <> imports') + (sm <> sm' <> newImportedSM) + (scs <> scs' <> newImportedSC) + where + ModuleEnv imports sm scs = withExtEvidence frag $ sink mEnv + ModuleEnv imports' sm' scs' = mEnv' + newDirectImports = S.difference (directImports imports') (directImports imports) + newTransImports = S.difference (transImports imports') (transImports imports) + newImportedSM = flip foldMap newDirectImports $ moduleExports . lookupModulePure + newImportedSC = flip foldMap newTransImports $ moduleSynthCandidates . lookupModulePure + + lookupModulePure v = case lookupEnvPure newTopEnv v of ModuleBinding m -> m + +applyUpdate :: TopEnv n -> TopEnvUpdate n -> TopEnv n +applyUpdate e = \case + ExtendCache cache -> e { envCache = envCache e <> cache} + AddCustomRule x y -> e { envCustomRules = envCustomRules e <> CustomRules (M.singleton x y)} + UpdateLoadedModules x y -> e { envLoadedModules = envLoadedModules e <> LoadedModules (M.singleton x y)} + UpdateLoadedObjects x y -> e { envLoadedObjects = envLoadedObjects e <> LoadedObjects (M.singleton x y)} + FinishDictSpecialization dName methods -> do + let SpecializedDictBinding (SpecializedDict dAbs oldMethods) = lookupEnvPure e dName + case oldMethods of + Nothing -> do + let newBinding = SpecializedDictBinding $ SpecializedDict dAbs (Just methods) + updateEnv dName newBinding e + Just _ -> error "shouldn't be adding methods if we already have them" + LowerDictSpecialization dName methods -> do + let SpecializedDictBinding (SpecializedDict dAbs _) = lookupEnvPure e dName + let newBinding = SpecializedDictBinding $ SpecializedDict dAbs (Just methods) + updateEnv dName newBinding e + UpdateTopFunEvalStatus f s -> do + case lookupEnvPure e f of + TopFunBinding (DexTopFun def lam _) -> + updateEnv f (TopFunBinding $ DexTopFun def lam s) e + _ -> error "can't update ffi function impl" + UpdateInstanceDef name def -> do + case lookupEnvPure e name of + InstanceBinding _ ty -> updateEnv name (InstanceBinding def ty) e + UpdateTyConDef name def -> do + let TyConBinding _ methods = lookupEnvPure e name + updateEnv name (TyConBinding (Just def) methods) e + UpdateFieldDef name sn x -> do + let TyConBinding def methods = lookupEnvPure e name + updateEnv name (TyConBinding def (methods <> DotMethods (M.singleton sn x))) e + +updateEnv :: Color c => Name c n -> Binding c n -> TopEnv n -> TopEnv n +updateEnv v rhs env = + env { envDefs = RecSubst $ updateSubstFrag v rhs bs } + where (RecSubst bs) = envDefs env + +lookupEnvPure :: Color c => TopEnv n -> Name c n -> Binding c n +lookupEnvPure env v = lookupTerminalSubstFrag (fromRecSubst $ envDefs $ env) v + +instance GenericE Module where + type RepE Module = LiftE ModuleSourceName + `PairE` ListE ModuleName + `PairE` ListE ModuleName + `PairE` SourceMap + `PairE` SynthCandidates + + fromE (Module name deps transDeps sm sc) = + LiftE name `PairE` ListE (S.toList deps) `PairE` ListE (S.toList transDeps) + `PairE` sm `PairE` sc + {-# INLINE fromE #-} + + toE (LiftE name `PairE` ListE deps `PairE` ListE transDeps + `PairE` sm `PairE` sc) = + Module name (S.fromList deps) (S.fromList transDeps) sm sc + {-# INLINE toE #-} + +instance SinkableE Module +instance HoistableE Module +instance AlphaEqE Module +instance AlphaHashableE Module +instance RenameE Module + +instance GenericE ImportStatus where + type RepE ImportStatus = ListE ModuleName `PairE` ListE ModuleName + fromE (ImportStatus direct trans) = ListE (S.toList direct) + `PairE` ListE (S.toList trans) + {-# INLINE fromE #-} + toE (ListE direct `PairE` ListE trans) = + ImportStatus (S.fromList direct) (S.fromList trans) + {-# INLINE toE #-} + +instance SinkableE ImportStatus +instance HoistableE ImportStatus +instance AlphaEqE ImportStatus +instance AlphaHashableE ImportStatus +instance RenameE ImportStatus + +instance Semigroup (ImportStatus n) where + ImportStatus direct trans <> ImportStatus direct' trans' = + ImportStatus (direct <> direct') (trans <> trans') + +instance Monoid (ImportStatus n) where + mappend = (<>) + mempty = ImportStatus mempty mempty + +instance GenericE LoadedModules where + type RepE LoadedModules = ListE (PairE (LiftE ModuleSourceName) ModuleName) + fromE (LoadedModules m) = + ListE $ M.toList m <&> \(v,md) -> PairE (LiftE v) md + {-# INLINE fromE #-} + toE (ListE pairs) = + LoadedModules $ M.fromList $ pairs <&> \(PairE (LiftE v) md) -> (v, md) + {-# INLINE toE #-} + +instance SinkableE LoadedModules +instance HoistableE LoadedModules +instance AlphaEqE LoadedModules +instance AlphaHashableE LoadedModules +instance RenameE LoadedModules + +instance GenericE LoadedObjects where + type RepE LoadedObjects = ListE (PairE FunObjCodeName (LiftE NativeFunction)) + fromE (LoadedObjects m) = + ListE $ M.toList m <&> \(v,p) -> PairE v (LiftE p) + {-# INLINE fromE #-} + toE (ListE pairs) = + LoadedObjects $ M.fromList $ pairs <&> \(PairE v (LiftE p)) -> (v, p) + {-# INLINE toE #-} + +instance SinkableE LoadedObjects +instance HoistableE LoadedObjects +instance RenameE LoadedObjects + +instance GenericE ModuleEnv where + type RepE ModuleEnv = ImportStatus + `PairE` SourceMap + `PairE` SynthCandidates + fromE (ModuleEnv imports sm sc) = imports `PairE` sm `PairE` sc + {-# INLINE fromE #-} + toE (imports `PairE` sm `PairE` sc) = ModuleEnv imports sm sc + {-# INLINE toE #-} + +instance SinkableE ModuleEnv +instance HoistableE ModuleEnv +instance AlphaEqE ModuleEnv +instance AlphaHashableE ModuleEnv +instance RenameE ModuleEnv + +instance Semigroup (ModuleEnv n) where + ModuleEnv x1 x2 x3 <> ModuleEnv y1 y2 y3 = + ModuleEnv (x1<>y1) (x2<>y2) (x3<>y3) + +instance Monoid (ModuleEnv n) where + mempty = ModuleEnv mempty mempty mempty + +instance Semigroup (LoadedModules n) where + LoadedModules m1 <> LoadedModules m2 = LoadedModules (m2 <> m1) + +instance Monoid (LoadedModules n) where + mempty = LoadedModules mempty + +instance Semigroup (LoadedObjects n) where + LoadedObjects m1 <> LoadedObjects m2 = LoadedObjects (m2 <> m1) + +instance Monoid (LoadedObjects n) where + mempty = LoadedObjects mempty + + +-- === instance === + +prettyRecord :: [(String, Doc ann)] -> Doc ann +prettyRecord xs = foldMap (\(name, val) -> pretty name <> indented val) xs + +instance Pretty (TopEnv n) where + pretty (TopEnv defs rules cache _ _) = + prettyRecord [ ("Defs" , pretty defs) + , ("Rules" , pretty rules) + , ("Cache" , pretty cache) ] + +instance Pretty (CustomRules n) where + pretty _ = "TODO: Rule printing" + +instance Pretty (ImportStatus n) where + pretty imports = pretty $ S.toList $ directImports imports + +instance Pretty (ModuleEnv n) where + pretty (ModuleEnv imports sm sc) = + prettyRecord [ ("Imports" , pretty imports) + , ("Source map" , pretty sm) + , ("Synth candidates", pretty sc) ] + +instance Pretty (Env n) where + pretty (Env env1 env2) = + prettyRecord [ ("Top env" , pretty env1) + , ("Module env", pretty env2)] + +instance Pretty (SolverBinding n) where + pretty (InfVarBound ty) = "Inference variable of type:" <+> pretty ty + pretty (SkolemBound ty) = "Skolem variable of type:" <+> pretty ty + pretty (DictBound ty) = "Dictionary variable of type:" <+> pretty ty + +instance Pretty (Binding c n) where + pretty b = case b of + -- using `unsafeCoerceIRE` here because otherwise we don't have `IRRep` + -- TODO: can we avoid printing needing IRRep? Presumably it's related to + -- manipulating sets or something, which relies on Eq/Ord, which relies on renaming. + AtomNameBinding info -> "Atom name:" <+> pretty (unsafeCoerceIRE @CoreIR info) + TyConBinding dataDef _ -> "Type constructor: " <+> pretty dataDef + DataConBinding tyConName idx -> "Data constructor:" <+> + pretty tyConName <+> "Constructor index:" <+> pretty idx + ClassBinding classDef -> pretty classDef + InstanceBinding instanceDef _ -> pretty instanceDef + MethodBinding className idx -> "Method" <+> pretty idx <+> "of" <+> pretty className + TopFunBinding f -> pretty f + FunObjCodeBinding _ -> "" + ModuleBinding _ -> "" + PtrBinding _ _ -> "" + SpecializedDictBinding _ -> "" + ImpNameBinding ty -> "Imp name of type: " <+> pretty ty + +instance Pretty (Module n) where + pretty m = prettyRecord + [ ("moduleSourceName" , pretty $ moduleSourceName m) + , ("moduleDirectDeps" , pretty $ S.toList $ moduleDirectDeps m) + , ("moduleTransDeps" , pretty $ S.toList $ moduleTransDeps m) + , ("moduleExports" , pretty $ moduleExports m) + , ("moduleSynthCandidates", pretty $ moduleSynthCandidates m) ] + +instance Pretty a => Pretty (EvalStatus a) where + pretty = \case + Waiting -> "" + Running -> "" + Finished a -> pretty a + +instance Pretty (EnvFrag n l) where + pretty (EnvFrag bindings) = pretty bindings + +instance Pretty (Cache n) where + pretty (Cache _ _ _ _ _ _) = "" -- TODO + +instance Pretty (SynthCandidates n) where + pretty scs = "instance dicts:" <+> pretty (M.toList $ instanceDicts scs) + +instance Pretty (LoadedModules n) where + pretty _ = "" + +instance Pretty (TopFunDef n) where + pretty = \case + Specialization s -> pretty s + LinearizationPrimal _ -> "" + LinearizationTangent _ -> "" + +instance Pretty (TopFun n) where + pretty = \case + DexTopFun def lam lowering -> + "Top-level Function" + <> hardline <+> "definition:" <+> pretty def + <> hardline <+> "lambda:" <+> pretty lam + <> hardline <+> "lowering:" <+> pretty lowering + FFITopFun f _ -> pretty f + +instance IRRep r => Pretty (TopLam r n) where + pretty (TopLam _ _ lam) = pretty lam + +instance IRRep r => Pretty (AtomBinding r n) where + pretty binding = case binding of + LetBound b -> pretty b + MiscBound t -> pretty t + SolverBound b -> pretty b + FFIFunBound s _ -> pretty s + NoinlineFun ty _ -> "Top function with type: " <+> pretty ty + TopDataBound (RepVal ty _) -> "Top data with type: " <+> pretty ty + +instance Pretty (SpecializationSpec n) where + pretty (AppSpecialization f (Abs bs (ListE args))) = + "Specialization" <+> pretty f <+> pretty bs <+> pretty args + +instance Hashable a => Hashable (EvalStatus a) + +instance Store (SolverBinding n) +instance IRRep r => Store (AtomBinding r n) +instance IRRep r => Store (TopLam r n) +instance Store (SynthCandidates n) +instance Store (Module n) +instance Store (ImportStatus n) +instance Store (TopFunLowerings n) +instance Store a => Store (EvalStatus a) +instance Store (TopFun n) +instance Store (TopFunDef n) +instance Color c => Store (Binding c n) +instance Store (ModuleEnv n) +instance Store (SerializedEnv n) +instance Store (AtomRules n) +instance Store (LinearizationSpec n) +instance Store (SpecializedDictDef n) +instance Store (SpecializationSpec n) diff --git a/src/lib/Util.hs b/src/lib/Util.hs index 7f9a89859..4dbc43edc 100644 --- a/src/lib/Util.hs +++ b/src/lib/Util.hs @@ -12,19 +12,21 @@ import Prelude import qualified Data.Set as Set import qualified Data.Map.Strict as M import Control.Applicative +import Control.Monad.Reader import Control.Monad.State.Strict import System.CPUTime import GHC.Base (getTag) import GHC.Exts ((==#), tagToEnum#) import Crypto.Hash import Data.Functor.Identity (Identity(..)) -import Data.Maybe (catMaybes) +import Data.Maybe (catMaybes, mapMaybe) import Data.List (sort) import Data.Hashable (Hashable) import Data.Store (Store) import qualified Data.List.NonEmpty as NE import qualified Data.ByteString as BS import Data.Foldable +import Data.Text.Prettyprint.Doc (Pretty (..), pretty) import Data.List.NonEmpty (NonEmpty (..)) import GHC.Generics (Generic) @@ -133,12 +135,8 @@ mapFst f zs = [(f x, y) | (x, y) <- zs] mapSnd :: (a -> b) -> [(c, a)] -> [(c, b)] mapSnd f zs = [(x, f y) | (x, y) <- zs] -mapMaybe :: (a -> Maybe b) -> [a] -> [b] -mapMaybe _ [] = [] -mapMaybe f (x:xs) = let rest = mapMaybe f xs - in case f x of - Just y -> y : rest - Nothing -> rest +foldJusts :: Monoid b => [a] -> (a -> Maybe b) -> b +foldJusts xs f = fold $ mapMaybe f xs forMFilter :: Monad m => [a] -> (a -> m (Maybe b)) -> m [b] forMFilter xs f = catMaybes <$> mapM f xs @@ -306,7 +304,7 @@ getAlternative xs = asum $ map pure xs {-# INLINE getAlternative #-} newtype SnocList a = ReversedList { fromReversedList :: [a] } - deriving Functor -- XXX: NOT deriving order-sensitive things like Monoid, Applicative etc + deriving (Show, Eq, Ord, Generic, Functor) -- XXX: NOT deriving order-sensitive things like Monoid, Applicative etc instance Semigroup (SnocList a) where (ReversedList x) <> (ReversedList y) = ReversedList $ y ++ x @@ -320,6 +318,10 @@ instance Foldable SnocList where foldMap f (ReversedList xs) = foldMap f (reverse xs) {-# INLINE foldMap #-} +instance Traversable SnocList where + traverse f (ReversedList xs) = ReversedList . reverse <$> traverse f (reverse xs) + {-# INLINE traverse #-} + snoc :: SnocList a -> a -> SnocList a snoc (ReversedList xs) x = ReversedList (x:xs) {-# INLINE snoc #-} @@ -353,6 +355,11 @@ zipTrees (Leaf x) (Leaf y) = Leaf (x, y) zipTrees (Branch xs) (Branch ys) | length xs == length ys = Branch $ zipWith zipTrees xs ys zipTrees _ _ = error "zip error" +instance Pretty a => Pretty (Tree a) where + pretty = \case + Leaf x -> pretty x + Branch xs -> pretty xs + -- === bytestrings paired with their hash digest === -- TODO: use something other than a string to store the digest diff --git a/src/lib/Vectorize.hs b/src/lib/Vectorize.hs index 88a6ef48e..daa606fc6 100644 --- a/src/lib/Vectorize.hs +++ b/src/lib/Vectorize.hs @@ -9,7 +9,7 @@ module Vectorize (vectorizeLoops) where import Prelude hiding ((.)) import Data.Word import Data.Functor -import Data.Text.Prettyprint.Doc (Pretty, pretty, viaShow, (<+>)) +import Data.Text.Prettyprint.Doc (viaShow) import Control.Category import Control.Monad.Reader import Control.Monad.State.Strict @@ -26,6 +26,7 @@ import Subst import PPrint import QueryType import Types.Core +import Types.Top import Types.OpNames qualified as P import Types.Primitives import Util (allM, zipWithZ) @@ -85,13 +86,13 @@ newtype TopVectorizeM (i::S) (o::S) (a:: *) = TopVectorizeM SubstReaderT Name (ReaderT1 CommuteMap (ReaderT1 (LiftE Word32) - (StateT1 (LiftE Errs) (BuilderT SimpIR FallibleM)))) i o a } + (StateT1 (LiftE [Err]) (BuilderT SimpIR Except)))) i o a } deriving ( Functor, Applicative, Monad, MonadFail, MonadReader (CommuteMap o) - , MonadState (LiftE Errs o), Fallible, ScopeReader, EnvReader + , MonadState (LiftE [Err] o), Fallible, ScopeReader, EnvReader , EnvExtender, Builder SimpIR, ScopableBuilder SimpIR, Catchable , SubstReader Name) -vectorizeLoops :: EnvReader m => Word32 -> STopLam n -> m n (STopLam n, Errs) +vectorizeLoops :: EnvReader m => Word32 -> STopLam n -> m n (STopLam n, [Err]) vectorizeLoops width (TopLam d ty (LamExpr bsDestB body)) = liftEnvReaderM do case popNest bsDestB of Just (PairB bs b) -> @@ -102,33 +103,24 @@ vectorizeLoops width (TopLam d ty (LamExpr bsDestB body)) = liftEnvReaderM do {-# SCC vectorizeLoops #-} liftTopVectorizeM :: (EnvReader m) - => Word32 -> TopVectorizeM i i a -> m i (a, Errs) + => Word32 -> TopVectorizeM i i a -> m i (a, [Err]) liftTopVectorizeM vectorByteWidth action = do fallible <- liftBuilderT $ flip runStateT1 mempty $ runReaderT1 (LiftE vectorByteWidth) $ runReaderT1 mempty $ runSubstReaderT idSubst $ runTopVectorizeM action - case runFallibleM fallible of + case fallible of -- The failure case should not occur: vectorization errors should have been -- caught inside `vectorizeLoopsDecls` (and should have been added to the - -- `Errs` state of the `StateT` instance that is run with `runStateT` above). + -- `Err` state of the `StateT` instance that is run with `runStateT` above). Failure errs -> error $ pprint errs Success (a, (LiftE errs)) -> return $ (a, errs) -addVectErrCtx :: Fallible m => String -> String -> m a -> m a -addVectErrCtx name payload m = - let ctx = mempty { messageCtx = ["In `" ++ name ++ "`:\n" ++ payload] } - in addErrCtx ctx m - throwVectErr :: Fallible m => String -> m a -throwVectErr msg = throwErr (Err MiscErr mempty msg) - -prependCtxToErrs :: ErrCtx -> Errs -> Errs -prependCtxToErrs ctx (Errs errs) = - Errs $ map (\(Err ty ctx' msg) -> Err ty (ctx <> ctx') msg) errs +throwVectErr msg = throwInternal msg askVectorByteWidth :: TopVectorizeM i o Word32 -askVectorByteWidth = TopVectorizeM $ SubstReaderT $ lift $ lift11 (fromLiftE <$> ask) +askVectorByteWidth = TopVectorizeM $ liftSubstReaderT $ lift11 (fromLiftE <$> ask) extendCommuteMap :: AtomName SimpIR o -> MonoidCommutes -> TopVectorizeM i o a -> TopVectorizeM i o a extendCommuteMap name commutativity = local $ insertNameMapE name $ LiftE commutativity @@ -139,26 +131,21 @@ vectorizeLoopsDestBlock (Abs (destb:>destTy) body) = do destTy' <- renameM destTy withFreshBinder (getNameHint destb) destTy' \destb' -> do extendRenamer (destb @> binderName destb') do - Abs destb' <$> buildBlock (vectorizeLoopsBlock body) - -vectorizeLoopsBlock :: (Emits o) - => Block SimpIR i -> TopVectorizeM i o (SAtom o) -vectorizeLoopsBlock (Abs decls ans) = - vectorizeLoopsDecls decls $ renameM ans + Abs destb' <$> buildBlock (vectorizeLoopsExpr body) vectorizeLoopsDecls :: (Emits o) => Nest SDecl i i' -> TopVectorizeM i' o a -> TopVectorizeM i o a vectorizeLoopsDecls nest cont = case nest of Empty -> cont - Nest (Let b (DeclBinding ann expr)) rest -> do - v <- emitDecl (getNameHint b) ann =<< vectorizeLoopsExpr expr + Nest (Let b (DeclBinding _ expr)) rest -> do + v <- emitToVar =<< vectorizeLoopsExpr expr extendSubst (b @> atomVarName v) $ vectorizeLoopsDecls rest cont vectorizeLoopsLamExpr :: LamExpr SimpIR i -> TopVectorizeM i o (LamExpr SimpIR o) vectorizeLoopsLamExpr (LamExpr bs body) = case bs of - Empty -> LamExpr Empty <$> buildBlock (vectorizeLoopsBlock body) + Empty -> LamExpr Empty <$> buildBlock (vectorizeLoopsExpr body) Nest (b:>ty) rest -> do ty' <- renameM ty withFreshBinder (getNameHint b) ty' \b' -> do @@ -166,12 +153,13 @@ vectorizeLoopsLamExpr (LamExpr bs body) = case bs of LamExpr bs' body' <- vectorizeLoopsLamExpr $ LamExpr rest body return $ LamExpr (Nest b' bs') body' -vectorizeLoopsExpr :: (Emits o) => SExpr i -> TopVectorizeM i o (SExpr o) +vectorizeLoopsExpr :: (Emits o) => SExpr i -> TopVectorizeM i o (SAtom o) vectorizeLoopsExpr expr = do vectorByteWidth <- askVectorByteWidth narrowestTypeByteWidth <- getNarrowestTypeByteWidth =<< renameM expr let loopWidth = vectorByteWidth `div` narrowestTypeByteWidth case expr of + Block _ (Abs decls body) -> vectorizeLoopsDecls decls $ vectorizeLoopsExpr body PrimOp (DAMOp (Seq effs dir ixty dest body)) -> do sz <- simplifyIxSize =<< renameM ixty case sz of @@ -183,14 +171,10 @@ vectorizeLoopsExpr expr = do let vn = n `div` loopWidth body' <- vectorizeSeq loopWidth ixty body dest' <- renameM dest - seqOp <- mkSeq dir (IxType IdxRepTy (IxDictRawFin (IdxRepVal vn))) dest' body' - return $ PrimOp $ DAMOp seqOp) - else renameM expr) - `catchErr` \errs -> do - let msg = "In `vectorizeLoopsDecls`:\nExpr:\n" ++ pprint expr - ctx = mempty { messageCtx = [msg] } - errs' = prependCtxToErrs ctx errs - modify (<> LiftE errs') + emit =<< mkSeq dir (IxType IdxRepTy (DictCon (IxRawFin (IdxRepVal vn)))) dest' body') + else renameM expr >>= emit) + `catchErr` \err -> do + modify (\(LiftE errs) -> LiftE (err:errs)) recurSeq expr _ -> recurSeq expr PrimOp (Hof (TypedHof _ (RunReader item (BinaryLamExpr hb' refb' body)))) -> do @@ -199,8 +183,8 @@ vectorizeLoopsExpr expr = do lam <- buildEffLam noHint itemTy \hb refb -> extendRenamer (hb' @> atomVarName hb) do extendRenamer (refb' @> atomVarName refb) do - vectorizeLoopsBlock body - PrimOp . Hof <$> mkTypedHof (RunReader item' lam) + vectorizeLoopsExpr body + emit =<< mkTypedHof (RunReader item' lam) PrimOp (Hof (TypedHof (EffTy _ ty) (RunWriter (Just dest) monoid (BinaryLamExpr hb' refb' body)))) -> do dest' <- renameM dest @@ -211,24 +195,24 @@ vectorizeLoopsExpr expr = do extendRenamer (hb' @> atomVarName hb) do extendRenamer (refb' @> atomVarName refb) do extendCommuteMap (atomVarName hb) commutativity do - vectorizeLoopsBlock body - PrimOp . Hof <$> mkTypedHof (RunWriter (Just dest') monoid' lam) - _ -> renameM expr + vectorizeLoopsExpr body + emit =<< mkTypedHof (RunWriter (Just dest') monoid' lam) + _ -> renameM expr >>= emit where - recurSeq :: (Emits o) => SExpr i -> TopVectorizeM i o (SExpr o) + recurSeq :: (Emits o) => SExpr i -> TopVectorizeM i o (SAtom o) recurSeq (PrimOp (DAMOp (Seq effs dir ixty dest body))) = do effs' <- renameM effs ixty' <- renameM ixty dest' <- renameM dest body' <- vectorizeLoopsLamExpr body - return $ PrimOp $ DAMOp $ Seq effs' dir ixty' dest' body' + emit $ Seq effs' dir ixty' dest' body' recurSeq _ = error "Impossible" simplifyIxSize :: (EnvReader m, ScopableBuilder SimpIR m) => IxType SimpIR n -> m n (Maybe Word32) simplifyIxSize ixty = do sizeMethod <- buildBlock $ applyIxMethod (sink $ ixTypeDict ixty) Size [] - cheapReduce sizeMethod >>= \case + reduceExpr sizeMethod >>= \case Just (IdxRepVal n) -> return $ Just n _ -> return Nothing {-# INLINE simplifyIxSize #-} @@ -262,7 +246,7 @@ isAdditionMonoid monoid = do BaseMonoid { baseEmpty = (Con (Lit l)) , baseCombine = BinaryLamExpr (b1:>_) (b2:>_) body } <- Just monoid unless (_isZeroLit l) Nothing - PrimOp (BinOp op (Var b1') (Var b2')) <- exprBlock body + PrimOp (BinOp op (Stuck _ (Var b1')) (Stuck _ (Var b2'))) <- return body unless (op `elem` [P.IAdd, P.FAdd]) Nothing case (binderName b1, atomVarName b1', binderName b2, atomVarName b2') of -- Checking the raw names here because (i) I don't know how to convince the @@ -306,7 +290,7 @@ vectorSafeEffect (EffectRow effs NoTail) = allM safe $ eSetToList effs where safe :: Effect SimpIR i -> TopVectorizeM i o Bool safe InitEffect = return True safe (RWSEffect Reader _) = return True - safe (RWSEffect Writer (Var h)) = do + safe (RWSEffect Writer (Stuck _ (Var h))) = do h' <- renameM $ atomVarName h commuteMap <- ask case lookupNameMapE h' commuteMap of @@ -319,27 +303,27 @@ vectorizeSeq :: Word32 -> IxType SimpIR i -> LamExpr SimpIR i -> TopVectorizeM i o (LamExpr SimpIR o) vectorizeSeq loopWidth ixty (UnaryLamExpr (b:>ty) body) = do newLoopTy <- case ty of - ProdTy [_ixType, ref] -> do + TyCon (ProdType [_ixType, ref]) -> do ref' <- renameM ref - return $ ProdTy [IdxRepTy, ref'] + return $ TyCon $ ProdType [IdxRepTy, ref'] _ -> error "Unexpected seq binder type" ixty' <- renameM ixty liftVectorizeM loopWidth $ buildUnaryLamExpr (getNameHint b) newLoopTy \ci -> do -- The per-tile loop iterates on `Fin` - (viOrd, dest) <- fromPair $ Var ci + (viOrd, dest) <- fromPair $ toAtom ci iOrd <- imul viOrd $ IdxRepVal loopWidth -- TODO: It would be nice to cancel this UnsafeFromOrdinal with the -- Ordinal that will be taken later when indexing, but that should -- probably be a separate pass. i <- applyIxMethod (sink $ ixTypeDict ixty') UnsafeFromOrdinal [iOrd] extendSubst (b @> VVal (ProdStability [Contiguous, ProdStability [Uniform]]) (PairVal i dest)) $ - vectorizeBlock body $> UnitVal + vectorizeExpr body $> UnitVal vectorizeSeq _ _ _ = error "expected a unary lambda expression" newtype VectorizeM i o a = VectorizeM { runVectorizeM :: - SubstReaderT VSubstValC (BuilderT SimpIR (ReaderT Word32 FallibleM)) i o a } + SubstReaderT VSubstValC (BuilderT SimpIR (ReaderT Word32 Except)) i o a } deriving ( Functor, Applicative, Monad, Fallible, MonadFail , SubstReader VSubstValC , Builder SimpIR, EnvReader, EnvExtender , ScopeReader, ScopableBuilder SimpIR) @@ -349,15 +333,14 @@ liftVectorizeM :: (SubstReader Name m, EnvReader (m i), Fallible (m i o)) liftVectorizeM loopWidth action = do subst <- getSubst act <- liftBuilderT $ runSubstReaderT (newSubst $ vSubst subst) $ runVectorizeM action - let fallible = flip runReaderT loopWidth act - case runFallibleM fallible of + case flip runReaderT loopWidth act of Success a -> return a - Failure errs -> throwErrs errs -- re-raise inside ambient monad + Failure errs -> throwErr errs -- re-raise inside ambient monad where vSubst subst val = VRename $ subst ! val getLoopWidth :: VectorizeM i o Word32 -getLoopWidth = VectorizeM $ SubstReaderT $ ReaderT $ const $ ask +getLoopWidth = VectorizeM $ SubstReaderT $ const $ ask -- TODO When needed, can code a variant of this that also returns the Stability -- of the value returned by the LamExpr. @@ -366,35 +349,24 @@ vectorizeLamExpr :: LamExpr SimpIR i -> [Stability] vectorizeLamExpr (LamExpr bs body) argStabilities = case (bs, argStabilities) of (Empty, []) -> do LamExpr Empty <$> buildBlock (do - vectorizeBlock body >>= \case + vectorizeExpr body >>= \case (VVal _ ans) -> return ans - (VRename v) -> Var <$> toAtomVar v) + (VRename v) -> toAtom <$> toAtomVar v) (Nest (b:>ty) rest, (stab:stabs)) -> do ty' <- vectorizeType ty ty'' <- promoteTypeByStability ty' stab withFreshBinder (getNameHint b) ty'' \b' -> do var <- toAtomVar $ binderName b' - extendSubst (b @> VVal stab (Var var)) do + extendSubst (b @> VVal stab (toAtom var)) do LamExpr rest' body' <- vectorizeLamExpr (LamExpr rest body) stabs return $ LamExpr (Nest b' rest') body' _ -> error "Zip error" -vectorizeBlock :: Emits o => SBlock i -> VectorizeM i o (VAtom o) -vectorizeBlock block@(Abs decls (ans :: SAtom i')) = - addVectErrCtx "vectorizeBlock" ("Block:\n" ++ pprint block) $ - go decls - where - go :: Emits o => Nest SDecl i i' -> VectorizeM i o (VAtom o) - go = \case - Empty -> vectorizeAtom ans - Nest (Let b (DeclBinding _ expr)) rest -> do - v <- vectorizeExpr expr - extendSubst (b @> v) $ go rest - vectorizeExpr :: Emits o => SExpr i -> VectorizeM i o (VAtom o) -vectorizeExpr expr = addVectErrCtx "vectorizeExpr" ("Expr:\n" ++ pprint expr) do +vectorizeExpr expr = do case expr of - TabApp _ tbl [ix] -> do + Block _ block -> vectorizeBlock block + TabApp _ tbl ix -> do VVal Uniform tbl' <- vectorizeAtom tbl VVal Contiguous ix' <- vectorizeAtom ix case getType tbl' of @@ -402,13 +374,19 @@ vectorizeExpr expr = addVectErrCtx "vectorizeExpr" ("Expr:\n" ++ pprint expr) do vty <- getVectorType =<< case hoist tb a of HoistSuccess a' -> return a' HoistFailure _ -> throwVectErr "Can't vectorize dependent table application" - VVal Varying <$> emitOp (VectorOp $ VectorIdx tbl' ix' vty) + VVal Varying <$> emit (VectorIdx tbl' ix' vty) tblTy -> do throwVectErr $ "bad type: " ++ pprint tblTy ++ "\ntbl' : " ++ pprint tbl' Atom atom -> vectorizeAtom atom PrimOp op -> vectorizePrimOp op _ -> throwVectErr $ "Cannot vectorize expr: " ++ pprint expr +vectorizeBlock :: Emits o => SBlock i -> VectorizeM i o (VAtom o) +vectorizeBlock (Abs Empty body) = vectorizeExpr body +vectorizeBlock (Abs (Nest (Let b (DeclBinding _ rhs)) rest) body) = do + v <- vectorizeExpr rhs + extendSubst (b @> v) $ vectorizeBlock (Abs rest body) + vectorizeDAMOp :: Emits o => DAMOp SimpIR i -> VectorizeM i o (VAtom o) vectorizeDAMOp op = case op of @@ -416,11 +394,11 @@ vectorizeDAMOp op = VVal vref ref <- vectorizeAtom ref' sval@(VVal vval val) <- vectorizeAtom val' VVal Uniform <$> case (vref, vval) of - (Uniform , Uniform ) -> emitExpr $ PrimOp $ DAMOp $ Place ref val + (Uniform , Uniform ) -> emit $ Place ref val (Uniform , _ ) -> throwVectErr "Write conflict? This should never happen!" (Varying , _ ) -> throwVectErr "Vector scatter not implemented" - (Contiguous, Varying ) -> emitExpr $ PrimOp $ DAMOp $ Place ref val - (Contiguous, Contiguous) -> emitExpr . PrimOp . DAMOp . Place ref =<< ensureVarying sval + (Contiguous, Varying ) -> emit $ Place ref val + (Contiguous, Contiguous) -> emit . Place ref =<< ensureVarying sval _ -> throwVectErr "Not implemented yet" _ -> throwVectErr $ "Can't vectorize op: " ++ pprint op @@ -431,7 +409,7 @@ vectorizeRefOp ref' op = -- TODO A contiguous reference becomes a vector load producing a varying -- result. VVal Uniform ref <- vectorizeAtom ref' - VVal Uniform <$> emitOp (RefOp ref MAsk) + VVal Uniform <$> emit (RefOp ref MAsk) MExtend basemonoid' x' -> do VVal refStab ref <- vectorizeAtom ref' VVal xStab x <- vectorizeAtom x' @@ -448,16 +426,16 @@ vectorizeRefOp ref' op = Contiguous -> do vectorizeBaseMonoid basemonoid' Varying xStab s -> throwVectErr $ "Cannot vectorize reference with loop-varying stability " ++ show s - VVal Uniform <$> emitOp (RefOp ref $ MExtend basemonoid x) + VVal Uniform <$> emit (RefOp ref $ MExtend basemonoid x) IndexRef _ i' -> do VVal Uniform ref <- vectorizeAtom ref' VVal Contiguous i <- vectorizeAtom i' case getType ref of - TC (RefType _ (TabTy _ tb a)) -> do + TyCon (RefType _ (TabTy _ tb a)) -> do vty <- getVectorType =<< case hoist tb a of HoistSuccess a' -> return a' HoistFailure _ -> throwVectErr "Can't vectorize dependent table application" - VVal Contiguous <$> emitOp (VectorOp $ VectorSubref ref i vty) + VVal Contiguous <$> emit (VectorSubref ref i vty) refTy -> do throwVectErr do "bad type: " ++ pprint refTy ++ "\nref' : " ++ pprint ref' @@ -483,7 +461,7 @@ vectorizePrimOp op = case op of sx@(VVal vx x) <- vectorizeAtom arg let v = case vx of Uniform -> Uniform; _ -> Varying x' <- if vx /= v then ensureVarying sx else return x - VVal v <$> emitOp (UnOp opk x') + VVal v <$> emit (UnOp opk x') BinOp opk arg1 arg2 -> do sx@(VVal vx x) <- vectorizeAtom arg1 sy@(VVal vy y) <- vectorizeAtom arg2 @@ -494,7 +472,7 @@ vectorizePrimOp op = case op of _ -> Varying x' <- if v == Varying then ensureVarying sx else return x y' <- if v == Varying then ensureVarying sy else return y - VVal v <$> emitOp (BinOp opk x' y') + VVal v <$> emit (BinOp opk x' y') MiscOp (CastOp tyArg arg) -> do ty <- vectorizeType tyArg VVal vx x <- vectorizeAtom arg @@ -503,28 +481,29 @@ vectorizePrimOp op = case op of Varying -> getVectorType ty Contiguous -> return ty ProdStability _ -> throwVectErr "Unexpected cast of product type" - VVal vx <$> emitOp (MiscOp $ CastOp ty' x) + VVal vx <$> emit (CastOp ty' x) DAMOp op' -> vectorizeDAMOp op' RefOp ref op' -> vectorizeRefOp ref op' MemOp (PtrOffset arg1 arg2) -> do VVal Uniform ptr <- vectorizeAtom arg1 VVal Contiguous off <- vectorizeAtom arg2 - VVal Contiguous <$> emitOp (MemOp $ PtrOffset ptr off) + VVal Contiguous <$> emit (PtrOffset ptr off) MemOp (PtrLoad arg) -> do VVal Contiguous ptr <- vectorizeAtom arg BaseTy (PtrType (addrSpace, a)) <- return $ getType ptr BaseTy av <- getVectorType $ BaseTy a - ptr' <- emitOp $ MiscOp $ CastOp (BaseTy $ PtrType (addrSpace, av)) ptr - VVal Varying <$> emitOp (MemOp $ PtrLoad ptr') + ptr' <- emit $ CastOp (BaseTy $ PtrType (addrSpace, av)) ptr + VVal Varying <$> emit (PtrLoad ptr') -- Vectorizing IO might not always be safe! Here, we depend on vectorizeOp -- being picky about the IO-inducing ops it supports, and expect it to -- complain about FFI calls and the like. Hof (TypedHof _ (RunIO body)) -> do -- TODO: buildBlockAux? Abs decls (LiftE vy `PairE` y) <- buildScoped do - VVal vy y <- vectorizeBlock body + VVal vy y <- vectorizeExpr body return $ PairE (LiftE vy) y - VVal vy <$> emitHof (RunIO $ Abs decls y) + block <- mkBlock (Abs decls y) + VVal vy <$> emitHof (RunIO block) _ -> throwVectErr $ "Can't vectorize op: " ++ pprint op vectorizeType :: SType i -> VectorizeM i o (SType o) @@ -533,24 +512,31 @@ vectorizeType t = do fmapNamesM (uniformSubst subst) t vectorizeAtom :: SAtom i -> VectorizeM i o (VAtom o) -vectorizeAtom atom = addVectErrCtx "vectorizeAtom" ("Atom:\n" ++ pprint atom) do +vectorizeAtom atom = do case atom of - Var v -> lookupSubstM (atomVarName v) >>= \case - VRename v' -> VVal Uniform . Var <$> toAtomVar v' - v' -> return v' - -- Vectors of base newtypes are already newtype-stripped. - ProjectElt _ (ProjectProduct i) x -> do - VVal vv x' <- vectorizeAtom x - ov <- case vv of - ProdStability sbs -> return $ sbs !! i - _ -> throwVectErr "Invalid projection" - x'' <- normalizeProj (ProjectProduct i) x' - return $ VVal ov x'' - ProjectElt _ UnwrapNewtype _ -> error "Shouldn't have newtypes left" -- TODO: check statically - Con (Lit l) -> return $ VVal Uniform $ Con $ Lit l - _ -> do - subst <- getSubst - VVal Uniform <$> fmapNamesM (uniformSubst subst) atom + Stuck _ e -> vectorizeStuck e + Con con -> case con of + Lit l -> return $ VVal Uniform $ Con $ Lit l + _ -> do + subst <- getSubst + VVal Uniform <$> fmapNamesM (uniformSubst subst) atom + +vectorizeStuck :: SStuck i -> VectorizeM i o (VAtom o) +vectorizeStuck = \case + Var v -> lookupSubstM (atomVarName v) >>= \case + VRename v' -> VVal Uniform . toAtom <$> toAtomVar v' + v' -> return v' + StuckProject i x -> do + VVal vv x' <- vectorizeStuck x + ov <- case vv of + ProdStability sbs -> return $ sbs !! i + _ -> throwVectErr "Invalid projection" + x'' <- reduceProj i x' + return $ VVal ov x'' + -- TODO: think about this case + StuckTabApp _ _ -> throwVectErr $ "Cannot vectorize atom" + PtrVar _ _ -> throwVectErr $ "Cannot vectorize atom" + RepValAtom _ -> throwVectErr $ "Cannot vectorize atom" uniformSubst :: Color c => Subst VSubstValC i o -> Name c i -> AtomSubstVal c o uniformSubst subst n = case subst ! n of @@ -560,33 +546,32 @@ uniformSubst subst n = case subst ! n of _ -> error "Can't vectorize atom" getVectorType :: SType o -> VectorizeM i o (SType o) -getVectorType ty = addVectErrCtx "getVectorType" ("Type:\n" ++ pprint ty) do - case ty of - BaseTy (Scalar sbt) -> do - els <- getLoopWidth - return $ BaseTy $ Vector [els] sbt - -- TODO: Should we support tables? - _ -> throwVectErr $ "Can't make a vector of " ++ pprint ty +getVectorType ty = case ty of + BaseTy (Scalar sbt) -> do + els <- getLoopWidth + return $ BaseTy $ Vector [els] sbt + -- TODO: Should we support tables? + _ -> throwVectErr $ "Can't make a vector of " ++ pprint ty ensureVarying :: Emits o => VAtom o -> VectorizeM i o (SAtom o) ensureVarying (VVal s val) = case s of Varying -> return val Uniform -> do vty <- getVectorType $ getType val - emitOp $ VectorOp $ VectorBroadcast val vty + emit $ VectorBroadcast val vty -- Note that the implementation of this case will depend on val's type. Contiguous -> do let ty = getType val vty <- getVectorType ty case ty of BaseTy (Scalar sbt) -> do - bval <- emitOp $ VectorOp $ VectorBroadcast val vty - iota <- emitOp $ VectorOp $ VectorIota vty - emitOp $ BinOp (if isIntegral sbt then IAdd else FAdd) bval iota + bval <- emit $ VectorBroadcast val vty + iota <- emit $ VectorIota vty + emit $ BinOp (if isIntegral sbt then IAdd else FAdd) bval iota _ -> throwVectErr "Not implemented" ProdStability _ -> throwVectErr "Not implemented" ensureVarying (VRename v) = do - x <- Var <$> toAtomVar v + x <- toAtom <$> toAtomVar v ensureVarying (VVal Uniform x) promoteTypeByStability :: SType o -> Stability -> VectorizeM i o (SType o) @@ -595,8 +580,8 @@ promoteTypeByStability ty = \case Contiguous -> return ty Varying -> getVectorType ty ProdStability stabs -> case ty of - ProdTy elts -> ProdTy <$> zipWithZ promoteTypeByStability elts stabs - _ -> throw ZipErr "Type and stability" + TyCon (ProdType elts) -> TyCon <$> ProdType <$> zipWithZ promoteTypeByStability elts stabs + _ -> throwInternal "Zip error" -- === computing byte widths === @@ -622,12 +607,14 @@ instance ExprVisitorNoEmits (CalcWidthM i o) SimpIR i o where let ty = getType expr' modify (\(LiftE x) -> LiftE $ min (typeByteWidth ty) x) return expr' + Block _ (Abs decls result) -> mkBlock =<< visitDeclsNoEmits decls \decls' -> do + Abs decls' <$> visitExprNoEmits result _ -> fallback where fallback = visitGeneric expr typeByteWidth :: SType n -> Word32 typeByteWidth ty = case ty of - TC (BaseType bt) -> case bt of + BaseTy bt -> case bt of -- Currently only support vectorization of scalar types (cf. `getVectorType` above): Scalar _ -> fromInteger . toInteger $ sizeOf bt _ -> maxWord32 diff --git a/static/dynamic.html b/static/dynamic.html index 5e636424a..eb0111d13 100644 --- a/static/dynamic.html +++ b/static/dynamic.html @@ -21,8 +21,10 @@ -
- +
+ (hover over code for more information) +
+
diff --git a/static/index.js b/static/index.js index 4b6862b09..6aa93849d 100644 --- a/static/index.js +++ b/static/index.js @@ -16,174 +16,211 @@ var katexOptions = { trust: true }; -var cells = {}; - -function append_contents(key, contents) { - if (key in cells) { - var cur_cells = cells[key]; - } else { - var cell = document.createElement("div"); - cell.className = "cell"; - cells[key] = [cell]; - var cur_cells = [cell]; - } - for (var i = 0; i < contents.length; i++) { - for (var j = 0; j < cur_cells.length; j++) { - var node = lookup_address(cur_cells[j], contents[i][0]) - node.innerHTML += contents[i][1]; - } - } -} - -function lookup_address(cell, address) { - var node = cell - for (i = 0; i < address.length; i++) { - node = node.children[address[i]] - } - return node -} - -function renderHovertips() { - var spans = document.querySelectorAll(".code-span"); - Array.from(spans).map((span) => attachHovertip(span)); -} - -function attachHovertip(node) { - node.addEventListener("mouseover", (event) => highlightNode( event, node)); - node.addEventListener("mouseout" , (event) => removeHighlighting(event, node)); -} - -function highlightNode(event, node) { - event.stopPropagation(); - node.style.backgroundColor = "lightblue"; - node.style.outlineColor = "lightblue"; - node.style.outlineStyle = "solid"; - Array.from(node.children).map(function (child) { - if (isCodeSpanOrLeaf(child)) { - child.style.backgroundColor = "yellow"; - } - }) -} - -function isCodeSpanOrLeaf(node) { - return node.classList.contains("code-span") || node.classList.contains("code-span-leaf") - -} - -function removeHighlighting(event, node) { - event.stopPropagation(); - node.style.backgroundColor = null; - node.style.outlineColor = null; - node.style.outlineStyle = null; - Array.from(node.children).map(function (child) { - if (isCodeSpanOrLeaf(child)) { - child.style.backgroundColor = null; - } - }) -} - -function renderLaTeX() { +function renderLaTeX(root) { // Render LaTeX equations in prose blocks via KaTeX, if available. // Skip rendering if KaTeX is unavailable. if (typeof renderMathInElement == 'undefined') { return; } // Render LaTeX equations in prose blocks via KaTeX. - var proseBlocks = document.querySelectorAll(".prose-block"); + var proseBlocks = root.querySelectorAll(".prose-block"); Array.from(proseBlocks).map((proseBlock) => renderMathInElement(proseBlock, katexOptions) ); } -/** - * Rendering the Table of Contents / Navigation Bar - * 2 key functions - * - `updateNavigation()` which inserts/updates the navigation bar - * - and it's helper `extractStructure()` which extracts the structure of the page - * and adds ids to heading elements. -*/ -function updateNavigation() { - function navItemList(struct) { - var listEle = document.createElement('ol') - struct.children.forEach(childStruct=> - listEle.appendChild(navItem(childStruct)) - ); - return listEle; +var RENDER_MODE = Object.freeze({ + STATIC: "static", + DYNAMIC: "dynamic", +}) +var body = document.getElementById("main-output"); +var hoverInfoDiv = document.getElementById("hover-info"); + +// State of the system beyond the HTML +var cells = {} +var frozenHover = false; +var curHighlights = []; // HTML elements currently highlighted +var highlightMap = {} +var spanMap = {} +var hoverInfoMap = {} + +function removeHover() { + if (frozenHover) return; + hoverInfoDiv.innerHTML = "" + curHighlights.map(function (element) { + element.classList.remove("highlighted", "highlighted-leaf")}) + curHighlights = []; +} +function lookupSrcMap(m, cellId, srcId) { + let blockMap = m[cellId] + if (blockMap == null) { + return null + } else { + return blockMap[srcId]} +} +function applyHover(cellId, srcId) { + if (frozenHover) return; + applyHoverInfo(cellId, srcId) + applyHoverHighlights(cellId, srcId) +} +function applyHoverInfo(cellId, srcId) { + let hoverInfo = lookupSrcMap(hoverInfoMap, cellId, srcId) + if (hoverInfo !== undefined) { + hoverInfoDiv.innerHTML = hoverInfo } - function navItem(struct) { - var a = document.createElement('a'); - a.appendChild(document.createTextNode(struct.text)); - a.title = struct.text; - a.href = "#"+struct.id; - - var ele = document.createElement('li') - ele.appendChild(a) - ele.appendChild(navItemList(struct)); - return ele; +} +function getSpan(cellId, srcId) { + return lookupSrcMap(spanMap, cellId, srcId) +} +function applyHoverHighlights(cellId, srcId) { + let highlights = lookupSrcMap(highlightMap, cellId, srcId) + if (highlights == null) return + highlights.map(function (highlight) { + let [highlightType, highlightSrcId] = highlight + let highlightClass = getHighlightClass(highlightType) + addClass(cellId, highlightSrcId, highlightClass)}) +} +function addClass(cellId, srcId, className) { + let span = getSpan(cellId, srcId) + if (span !== undefined) { + let [l, r] = span + let spans = spansBetween(selectSpan(cellId, l), selectSpan(cellId, r)); + spans.map(function (span) { + span.classList.add(className) + curHighlights.push(span)})} +} +function toggleFrozenHover() { + if (frozenHover) { + frozenHover = false + removeHover() + } else { + frozenHover = true} +} +function attachHovertip(cellId, srcId) { + let span = selectSpan(cellId, srcId) + span.addEventListener("mouseover", function (event) { + event.stopPropagation() + applyHover(cellId, srcId)}) + span.addEventListener("mouseout" , function (event) { + event.stopPropagation() + removeHover()})} +function selectSpan(cellId, srcId) { + return cells[cellId].querySelector("#span_".concat(cellId, "_", srcId)) +} +function selectCell(cellId) { + return cells[cellId] +} +function getHighlightClass(highlightType) { + if (highlightType == "HighlightGroup") { + return "highlighted"; + } else if (highlightType == "HighlightLeaf") { + return "highlighted-leaf"; + } else { + throw new Error("Unrecognized highlight type"); } - - var navbarEle = document.getElementById("navbar") - if (navbarEle === null) { // create it - navbarEle = document.createElement("div"); - navbarEle.id="navbar"; - navOuterEle = document.createElement("nav") - navOuterEle.appendChild(navbarEle); - document.body.prepend(navOuterEle); +} +function getStatusClass(status) { + if (status == "Waiting") { + return "waiting-cell"; + } else if (status == "Running") { + return "running-cell"; + } else if (status == "Complete") { + return "complete-cell"; + } else { + throw new Error("Unrecognized status type"); } - - navbarEle.innerHTML = "" - var structure = extractStructure() - navbarEle.appendChild(navItemList(structure)); } - -function extractStructure() { // Also sets ids on h1,h2,... - var headingsNodes = document.querySelectorAll("h1, h2, h3, h4, h5, h6"); - // For now we are just fulling going to regenerate the structure each time - // Might be better if we made minimal changes, but 🤷 - - // Extract the structure of the document - var structure = {children:[]} - var active = [structure.children]; - headingsNodes.forEach( - function(currentValue, currentIndex) { - currentValue.id = "s-" + currentIndex; - var currentLevel = parseInt(currentValue.nodeName[1]); - - // Insert dummy levels up for any levels that are skipped - for (var i=active.length; i < currentLevel; i++) { - var dummy = {id: "", text: "", children: []} - active.push(dummy.children); - var parentList = active[i-1] - parentList.push(dummy); - } - // delete this level and everything after - active.splice(currentLevel, active.length); - - var currentStructure = { - id: currentValue.id, - text: currentValue.textContent, - children: [], - }; - active.push(currentStructure.children); - - var parentList = active[active.length-2] - parentList.push(currentStructure); - }, - ); - return structure; +function spansBetween(l, r) { + let spans = [] + while (l !== null && !(Object.is(l, r))) { + spans.push(l); + l = l.nextSibling;} + spans.push(r) + return spans +} +function setCellStatus(cell, status) { + cell.className = "cell" + cell.classList.add(getStatusClass(status)) +} +function addChild(cell, className, innerHTML) { + let child = document.createElement("div") + child.innerHTML = innerHTML + child.className = className + cell.appendChild(child) +} +function initializeCellContents(cellId, cell, contents) { + let [source, status, result] = contents; + let lineNum = source["rsbLine"]; + let sourceText = source["rsbHtml"]; + highlightMap[cellId] = {}; + hoverInfoMap[cellId] = {}; + spanMap[cellId] = {}; + addChild(cell, "line-num" , lineNum.toString()) + addChild(cell, "code-block" , sourceText) + addChild(cell, "cell-results", "") + setCellStatus(cell, status) + renderLaTeX(cell) + extendCellResult(cellId, cell, result) } +function extendCellResult(cellId, cell, result) { + let resultText = result["rrHtml"] + if (resultText !== "") { + let bodyDiv = cell.querySelector(".cell-results") + bodyDiv.innerHTML += resultText + } + Object.assign(highlightMap[cellId], result["rrHighlightMap"]) + Object.assign(hoverInfoMap[cellId], result["rrHoverInfoMap"]) + Object.assign(spanMap[cellId] , result["rrLexemeSpans"]) -/** - * HTML rendering mode. - * Static rendering is used for static HTML pages. - * Dynamic rendering is used for dynamic HTML pages via `dex web`. - * - * @enum {string} - */ -var RENDER_MODE = Object.freeze({ - STATIC: "static", - DYNAMIC: "dynamic", -}) + let errSrcIds = result["rrErrorSrcIds"] + errSrcIds.map(function (srcId) { + addClass(cellId, srcId, "err-span")}) +} +function updateCellContents(cellId, cell, contents) { + let [statusUpdate, result] = contents; + if (statusUpdate["tag"] == "OverwriteWith") { + setCellStatus(cell, statusUpdate["contents"])} + extendCellResult(cellId, cell, result) +} +function processUpdate(msg) { + let cellUpdates = msg["nodeMapUpdate"]["mapUpdates"]; + let numDropped = msg["orderedNodesUpdate"]["numDropped"]; + let newTail = msg["orderedNodesUpdate"]["newTail"]; + // drop_dead_cells + for (i = 0; i < numDropped; i++) { + body.lastElementChild.remove();} + + Object.keys(cellUpdates).forEach(function (cellId) { + let update = cellUpdates[cellId]; + let tag = update["tag"] + let contents = update["contents"] + if (tag == "Create" || tag == "Replace") { + let cell = document.createElement("div"); + cells[cellId] = cell; + initializeCellContents(cellId, cell, contents) + } else if (tag == "Update") { + let cell = cells[cellId]; + updateCellContents(cellId, cell, contents); + } else if (tag == "Delete") { + delete cells[cellId] + } else { + console.error(tag); + }}); + + // append_new_cells + newTail.forEach(function (cellId) { + let cell = selectCell(cellId); + body.appendChild(cell);}) + + Object.keys(cellUpdates).forEach(function (cellId) { + let update = cellUpdates[cellId] + let tag = update["tag"] + if (tag == "Create" || tag == "Replace") { + let update = cellUpdates[cellId]; + let source = update["contents"][0]; + let lexemeList = source["rsbLexemeList"]; + lexemeList.map(function (lexemeId) {attachHovertip(cellId, lexemeId.toString())})}}); +} /** * Renders the webpage. @@ -192,48 +229,19 @@ var RENDER_MODE = Object.freeze({ function render(renderMode) { if (renderMode == RENDER_MODE.STATIC) { // For static pages, simply call rendering functions once. - renderLaTeX(); - renderHovertips(); - updateNavigation(); + renderLaTeX(document); } else { // For dynamic pages (via `dex web`), listen to update events. var source = new EventSource("/getnext"); source.onmessage = function(event) { - var body = document.getElementById("main-output"); var msg = JSON.parse(event.data); if (msg == "start") { - body.innerHTML = ""; + body.innerHTML = "" + body.addEventListener("click", function (event) { + event.stopPropagation() + toggleFrozenHover()}) cells = {} return - } - var order = msg[0]; - var contents = msg[1]; - for (var i = 0; i < contents.length; i++) { - append_contents(contents[i][0], contents[i][1]); - } - if (order != null) { - var new_cells = {}; - body.innerHTML = ""; - for (var i = 0; i < order.val.length; i++) { - var key = order.val[i] - var cur_cells = cells[key] - if (cur_cells.length == 0) { - var cur_cell = new_cells[key][0].cloneNode(true) - } else { - var cur_cell = cur_cells.pop() - if (key in new_cells) { - new_cells[key].push(cur_cell); - } else { - new_cells[key] = [cur_cell]; - } - } - body.appendChild(cur_cell); - } - Object.assign(cells, new_cells); - } - renderLaTeX(); - renderHovertips(); - updateNavigation(); - }; - } + } else { + processUpdate(msg)}};} } diff --git a/static/style.css b/static/style.css index cab311add..450f70bfd 100644 --- a/static/style.css +++ b/static/style.css @@ -11,49 +11,39 @@ body { display: flex; justify-content: space-between; overflow-x: hidden; - - --main-width: 50rem; - --nav-width: 20rem; -} - -@media (max-width: 70rem) { - /*For narrow screens hide nav and enable horizontal scrolling */ - nav {display: none;} - body {overflow-x: auto;} + padding-bottom:50vw; } -nav {/* this actually just holds space for #navbar, which is fixed */ - min-width: var(--nav-width); - max-width: var(--nav-width); -} -#navbar { +#hover-info { position: fixed; - height: 100vh; - width: var(--nav-width); - overflow-y: scroll; - border-right: 1px solid firebrick; -} -#navbar:before { - content: "Contents"; - font-weight: bold; -} -nav ol { - list-style-type:none; - padding-left: 1rem; + height: 5rem; + bottom: 0em; + width: 100vw; + overflow: hidden; + background-color: white; + border-top: 1px solid firebrick; + font-family: monospace; + white-space: pre; } #main-output { - max-width: var(--main-width); margin: auto; } -.code-block, .err-block, .result-block { +.code-block { +} + +.code-block, .cell-results, .err-block, .result-block { + margin: 0em 0em 0em 4em; padding: 0em 0em 0em 2em; display: block; font-family: monospace; white-space: pre; } - +.err-span { + text-decoration: red wavy underline; + text-decoration-skip-ink: none; +} code { background-color: #F0F0F0; } @@ -96,6 +86,14 @@ code { color: #E07000; } +.highlighted { + background-color: yellow; +} + +.highlighted-leaf { + background-color: lightblue; +} + .type-name { color: #A80000; } @@ -103,3 +101,26 @@ code { .iso-sugar { color: #25BBA7; } + +.cell { +} + +.line-num { + display: block; + font-family: monospace; + width: 3em; + color: #808080; + float: left; + text-align: right; +} + +.waiting-cell { + border-left: 6px solid #AAAAFF; +} + +.running-cell { + border-left: 6px solid #AAFFAA; +} + +.complete-cell { +}