forked from swiftlang/swift
-
Notifications
You must be signed in to change notification settings - Fork 0
/
differentiable_attr_parse.swift
75 lines (60 loc) · 2.11 KB
/
differentiable_attr_parse.swift
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
// RUN: %target-swift-frontend -parse -verify %s
// Good: every attribute in this section must parse with no diagnostics.
// Fixture: reverse mode with a single 'adjoint:' function reference.
// 'bar' is redeclared throughout this file; that is only acceptable because
// the RUN line stops at '-parse'.
@differentiable(reverse, adjoint: foo(_:_:)) // okay
func bar(_ value: Float, _: Float) -> Float {
  let result = 1 + value
  return result
}
// Fixture: an 'adjoint:' reference followed by a trailing 'where' clause
// carrying a conformance requirement on the generic parameter.
@differentiable(reverse, adjoint: foo(_:_:) where T : FloatingPoint) // okay
func bar<T : Numeric>(_ value: T, _: T) -> T {
  let result = 1 + value
  return result
}
// Fixture: an explicit 'wrt:' list mixing 'self' with the leading-dot
// parameter indices '.0' and '.1'.
@differentiable(reverse, wrt: (self, .0, .1), adjoint: foo(_:_:)) // okay
func bar(_ value: Float, _: Float) -> Float {
  let result = 1 + value
  return result
}
// Fixture: all three configurations together -- 'wrt:', 'primal:' and
// 'adjoint:' -- after the differentiation mode.
@differentiable(reverse, wrt: (self, .0, .1), primal: bar, adjoint: foo(_:_:)) // okay
func bar(_ value: Float, _: Float) -> Float {
  let result = 1 + value
  return result
}
// Fixture: the differentiation mode on its own, with no configuration
// arguments.
@differentiable(reverse) // okay
func bar(_ value: Float, _: Float) -> Float {
  let result = 1 + value
  return result
}
// Fixture: '@differentiable' sandwiched between '@_transparent' and
// '@inlinable'. The attribute order is part of what is being parsed.
@_transparent
@differentiable(reverse) // okay
@inlinable
func playWellWithOtherAttrs(_ value: Float, _: Float) -> Float {
  let result = 1 + value
  return result
}
// Fixture: a single-element 'wrt: (self)' list plus an 'adjoint:' reference,
// combined with '@_transparent'.
// NOTE: no enclosing type is declared for this function, so 'self' / 'Self'
// would not type-check; the fixture relies on the RUN line stopping at '-parse'.
@_transparent
@differentiable(reverse, wrt: (self), adjoint: _adjointSquareRoot) // okay
public func squareRoot() -> Self {
  var root = self
  root.formSquareRoot()
  return root
}
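// Bad: every attribute in this section is malformed and must emit exactly the diagnostic in its 'expected-error' annotation.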
// The first argument is a configuration ('primal:') rather than a
// differentiation mode, so the parser reports the missing mode.
@differentiable(primal: bar) // expected-error {{expected a differentiation mode ('forward' or 'reverse')}}
func bar(_ x: Float, _: Float) -> Float {
return 1 + x
}
// After the mode, a bare integer literal appears where a labeled
// configuration ('wrt:', 'primal:' or 'adjoint:') is expected.
@differentiable(reverse, 3) // expected-error {{expected a configuration, e.g. 'wrt:', 'primal:' or 'adjoint:'}}
func bar(_ x: Float, _: Float) -> Float {
return 1 + x
}
// A function reference without a configuration label is rejected; the valid
// spelling is 'adjoint: foo(_:_:)', as in the Good section.
@differentiable(reverse, foo(_:_:)) // expected-error {{expected a configuration, e.g. 'wrt:', 'primal:' or 'adjoint:'}}
func bar(_ x: Float, _: Float) -> Float {
return 1 + x
}
// 'wrt:' entries must be a leading-dot parameter index (e.g. '.0') or 'self';
// the bare integer '1' is rejected.
@differentiable(reverse, wrt: (1), adjoint: foo(_:_:)) // expected-error {{expected a parameter, which can be the index of a function parameter with a leading dot (e.g. '.0'), or 'self'}}
func bar(_ x: Float, _: Float) -> Float {
return 1 + x
}
// The attribute's argument list is never closed. The next token the parser
// sees is the 'func' keyword below, so keep that declaration unchanged to keep
// error recovery (and thus the '-verify' result) stable.
@differentiable(reverse, adjoint: foo(_:_:) // expected-error {{expected ')' in 'differentiable' attribute}}
func bar(_ x: Float, _: Float) -> Float {
return 1 + x
}
// The trailing 'where' clause names 'T' but gives neither a ':' conformance
// nor an '==' same-type requirement.
@differentiable(reverse, adjoint: foo(_:_:) where T) // expected-error {{expected ':' or '==' to indicate a conformance or same-type requirement}}
func bar<T : Numeric>(_ x: T, _: T) -> T {
return 1 + x
}