forked from swiftlang/swift
-
Notifications
You must be signed in to change notification settings - Fork 0
/
autodiff_basic.sil
90 lines (80 loc) · 5.87 KB
/
autodiff_basic.sil
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
// RUN: %target-sil-opt -differentiation %s | %FileCheck %s
import Builtin
import Swift
sil_stage raw
// Adjoint function that @foo's [reverse_differentiable] attribute names via
// `adjoint @foo_adj`.
// Takes three Float arguments and returns the third one (%2) unchanged; %0 and
// %1 are not used in the body.
// NOTE(review): the expected output further down calls this with (original
// argument, result of @foo, second argument of the synthesized `_s_p`
// function) -- presumably (x, primal result, seed); confirm against the pass.
sil hidden @foo_adj : $@convention(thin) (Float, Float, Float) -> Float {
bb0(%0 : @trivial $Float, %1 : @trivial $Float, %2 : @trivial $Float):
return %2 : $Float
}
// Identity function on Float: returns its only argument unchanged.
// The [reverse_differentiable] attribute reads `source 0 wrt 0`, names @foo
// itself as the primal and @foo_adj as the adjoint.
// NOTE(review): `source 0` / `wrt 0` presumably select result index 0 and
// parameter index 0 -- confirm against the attribute's definition.
sil hidden [reverse_differentiable source 0 wrt 0 primal @foo adjoint @foo_adj] @foo : $@convention(thin) (Float) -> Float {
bb0(%0 : @trivial $Float):
return %0 : $Float
}
// Test input for the differentiation pass: forms two `gradient` instructions
// over @foo (one plain, one with [preserving_result]), applies the first to
// the argument, and applies the second to the first one's result.
sil hidden @bar : $@convention(thin) (Float) -> (Float, Float) {
bb0(%0 : @trivial $Float):
%1 = function_ref @foo : $@convention(thin) (Float) -> Float
// Plain gradient of @foo; it is applied below at type (Float) -> Float.
%2 = gradient [source 0] [wrt 0] %1 : $@convention(thin) (Float) -> Float
%3 = apply %2(%0) : $@convention(thin) (Float) -> Float
// Same operand, but with [preserving_result]; this one is applied at type
// (Float) -> (Float, Float).
%4 = gradient [source 0] [wrt 0] [preserving_result] %1 : $@convention(thin) (Float) -> Float
// The output of the first gradient call (%3) is the input here.
%5 = apply %4(%3) : $@convention(thin) (Float) -> (Float, Float)
return %5 : $(Float, Float)
}
// After the differentiation pass, every `gradient` instruction in @bar has been
// replaced by a `function_ref` to a synthesized gradient function.
// CHECK-LABEL: sil hidden @bar :
// CHECK: bb0
// CHECK: %1 = function_ref @AD__foo__grad_src_0_wrt_0 : $@convention(thin) (Float) -> Float
// CHECK: %2 = apply %1(%0) : $@convention(thin) (Float) -> Float
// CHECK: %3 = function_ref @AD__foo__grad_src_0_wrt_0_p : $@convention(thin) (Float) -> (Float, Float)
// CHECK: %4 = apply %3(%2) : $@convention(thin) (Float) -> (Float, Float)
// CHECK: return %4 : $(Float, Float)
// CHECK: }
// CHECK-LABEL: sil hidden @AD__foo__grad_src_0_wrt_0_s_p :
// CHECK: bb0
// CHECK: %2 = function_ref @foo : $@convention(thin) (Float) -> Float
// CHECK: %3 = apply %2(%0) : $@convention(thin) (Float) -> Float
// CHECK: %4 = function_ref @foo_adj : $@convention(thin) (Float, Float, Float) -> Float
// CHECK: %5 = apply %4(%0, %3, %1) : $@convention(thin) (Float, Float, Float) -> Float
// CHECK: %6 = tuple (%3 : $Float, %5 : $Float)
// CHECK: return %6 : $(Float, Float)
// CHECK: }
// CHECK-LABEL: sil hidden @AD__foo__grad_src_0_wrt_0 : $@convention(thin) (Float) -> Float {
// CHECK: bb0(%0 : $Float):
// CHECK: %1 = alloc_stack $Float
// CHECK: %2 = integer_literal $Builtin.Int2048, 1
// CHECK: %3 = metatype $@thick Int64.Type
// CHECK: %4 = witness_method $Int64, #_ExpressibleByBuiltinIntegerLiteral.init!allocator.1 : <Self where Self : _ExpressibleByBuiltinIntegerLiteral> (Self.Type) -> (Builtin.Int2048) -> Self : $@convention(witness_method: _ExpressibleByBuiltinIntegerLiteral) <τ_0_0 where τ_0_0 : _ExpressibleByBuiltinIntegerLiteral> (Builtin.Int2048, @thick τ_0_0.Type) -> @out τ_0_0
// CHECK: %5 = alloc_stack $Int64
// CHECK: %6 = apply %4<Int64>(%5, %2, %3) : $@convention(witness_method: _ExpressibleByBuiltinIntegerLiteral) <τ_0_0 where τ_0_0 : _ExpressibleByBuiltinIntegerLiteral> (Builtin.Int2048, @thick τ_0_0.Type) -> @out τ_0_0
// CHECK: %7 = metatype $@thick Float.Type
// CHECK: %8 = witness_method $Float, #ExpressibleByIntegerLiteral.init!allocator.1 : <Self where Self : ExpressibleByIntegerLiteral> (Self.Type) -> (Self.IntegerLiteralType) -> Self : $@convention(witness_method: ExpressibleByIntegerLiteral) <τ_0_0 where τ_0_0 : ExpressibleByIntegerLiteral> (@in τ_0_0.IntegerLiteralType, @thick τ_0_0.Type) -> @out τ_0_0
// CHECK: %9 = apply %8<Float>(%1, %5, %7) : $@convention(witness_method: ExpressibleByIntegerLiteral) <τ_0_0 where τ_0_0 : ExpressibleByIntegerLiteral> (@in τ_0_0.IntegerLiteralType, @thick τ_0_0.Type) -> @out τ_0_0
// CHECK: dealloc_stack %5 : $*Int64
// CHECK: %11 = begin_access [read] [static] %1 : $*Float
// CHECK: %12 = load %11 : $*Float
// CHECK: end_access %11 : $*Float
// CHECK: dealloc_stack %1 : $*Float
// CHECK: %15 = function_ref @AD__foo__grad_src_0_wrt_0_s_p : $@convention(thin) (Float, Float) -> (Float, Float)
// CHECK: %16 = apply %15(%0, %12) : $@convention(thin) (Float, Float) -> (Float, Float)
// CHECK: %17 = tuple_extract %16 : $(Float, Float), 1
// CHECK: return %17 : $Float
// CHECK: }
// CHECK-LABEL: sil hidden @AD__foo__grad_src_0_wrt_0_p : $@convention(thin) (Float) -> (Float, Float) {
// CHECK: bb0(%0 : $Float):
// CHECK: %1 = alloc_stack $Float
// CHECK: %2 = integer_literal $Builtin.Int2048, 1
// CHECK: %3 = metatype $@thick Int64.Type
// CHECK: %4 = witness_method $Int64, #_ExpressibleByBuiltinIntegerLiteral.init!allocator.1 : <Self where Self : _ExpressibleByBuiltinIntegerLiteral> (Self.Type) -> (Builtin.Int2048) -> Self : $@convention(witness_method: _ExpressibleByBuiltinIntegerLiteral) <τ_0_0 where τ_0_0 : _ExpressibleByBuiltinIntegerLiteral> (Builtin.Int2048, @thick τ_0_0.Type) -> @out τ_0_0
// CHECK: %5 = alloc_stack $Int64
// CHECK: %6 = apply %4<Int64>(%5, %2, %3) : $@convention(witness_method: _ExpressibleByBuiltinIntegerLiteral) <τ_0_0 where τ_0_0 : _ExpressibleByBuiltinIntegerLiteral> (Builtin.Int2048, @thick τ_0_0.Type) -> @out τ_0_0
// CHECK: %7 = metatype $@thick Float.Type
// CHECK: %8 = witness_method $Float, #ExpressibleByIntegerLiteral.init!allocator.1 : <Self where Self : ExpressibleByIntegerLiteral> (Self.Type) -> (Self.IntegerLiteralType) -> Self : $@convention(witness_method: ExpressibleByIntegerLiteral) <τ_0_0 where τ_0_0 : ExpressibleByIntegerLiteral> (@in τ_0_0.IntegerLiteralType, @thick τ_0_0.Type) -> @out τ_0_0
// CHECK: %9 = apply %8<Float>(%1, %5, %7) : $@convention(witness_method: ExpressibleByIntegerLiteral) <τ_0_0 where τ_0_0 : ExpressibleByIntegerLiteral> (@in τ_0_0.IntegerLiteralType, @thick τ_0_0.Type) -> @out τ_0_0
// CHECK: dealloc_stack %5 : $*Int64
// CHECK: %11 = begin_access [read] [static] %1 : $*Float
// CHECK: %12 = load %11 : $*Float
// CHECK: end_access %11 : $*Float
// CHECK: dealloc_stack %1 : $*Float
// CHECK: %15 = function_ref @AD__foo__grad_src_0_wrt_0_s_p : $@convention(thin) (Float, Float) -> (Float, Float)
// CHECK: %16 = apply %15(%0, %12) : $@convention(thin) (Float, Float) -> (Float, Float)
// CHECK: return %16 : $(Float, Float)
// CHECK: }