77import  inspect 
88from  typing  import  TYPE_CHECKING , Any , NamedTuple , Optional , Sequence , cast 
99
10- from  ._helpers  import  _check_device , array_namespace 
10+ from  ._helpers  import  _device_ctx , array_namespace 
1111from  ._helpers  import  device  as  _get_device 
1212from  ._helpers  import  is_cupy_namespace  as  _is_cupy_namespace 
1313from  ._typing  import  Array , Device , DType , Namespace 
@@ -32,8 +32,8 @@ def arange(
3232    device : Device  |  None  =  None ,
3333    ** kwargs : object ,
3434) ->  Array :
35-     _check_device (xp , device )
36-     return  xp .arange (start , stop = stop , step = step , dtype = dtype , ** kwargs )
35+     with   _device_ctx (xp , device ): 
36+          return  xp .arange (start , stop = stop , step = step , dtype = dtype , ** kwargs )
3737
3838
3939def  empty (
@@ -44,8 +44,8 @@ def empty(
4444    device : Device  |  None  =  None ,
4545    ** kwargs : object ,
4646) ->  Array :
47-     _check_device (xp , device )
48-     return  xp .empty (shape , dtype = dtype , ** kwargs )
47+     with   _device_ctx (xp , device ): 
48+          return  xp .empty (shape , dtype = dtype , ** kwargs )
4949
5050
5151def  empty_like (
@@ -57,8 +57,8 @@ def empty_like(
5757    device : Device  |  None  =  None ,
5858    ** kwargs : object ,
5959) ->  Array :
60-     _check_device (xp , device ) 
61-     return  xp .empty_like (x , dtype = dtype , ** kwargs )
60+     with   _device_ctx (xp , device ,  like = x ): 
61+          return  xp .empty_like (x , dtype = dtype , ** kwargs )
6262
6363
6464def  eye (
@@ -72,8 +72,8 @@ def eye(
7272    device : Device  |  None  =  None ,
7373    ** kwargs : object ,
7474) ->  Array :
75-     _check_device (xp , device )
76-     return  xp .eye (n_rows , M = n_cols , k = k , dtype = dtype , ** kwargs )
75+     with   _device_ctx (xp , device ): 
76+          return  xp .eye (n_rows , M = n_cols , k = k , dtype = dtype , ** kwargs )
7777
7878
7979def  full (
@@ -85,8 +85,8 @@ def full(
8585    device : Device  |  None  =  None ,
8686    ** kwargs : object ,
8787) ->  Array :
88-     _check_device (xp , device )
89-     return  xp .full (shape , fill_value , dtype = dtype , ** kwargs )
88+     with   _device_ctx (xp , device ): 
89+          return  xp .full (shape , fill_value , dtype = dtype , ** kwargs )
9090
9191
9292def  full_like (
@@ -99,8 +99,8 @@ def full_like(
9999    device : Device  |  None  =  None ,
100100    ** kwargs : object ,
101101) ->  Array :
102-     _check_device (xp , device ) 
103-     return  xp .full_like (x , fill_value , dtype = dtype , ** kwargs )
102+     with   _device_ctx (xp , device ,  like = x ): 
103+          return  xp .full_like (x , fill_value , dtype = dtype , ** kwargs )
104104
105105
106106def  linspace (
@@ -115,8 +115,8 @@ def linspace(
115115    endpoint : bool  =  True ,
116116    ** kwargs : object ,
117117) ->  Array :
118-     _check_device (xp , device )
119-     return  xp .linspace (start , stop , num , dtype = dtype , endpoint = endpoint , ** kwargs )
118+     with   _device_ctx (xp , device ): 
119+          return  xp .linspace (start , stop , num , dtype = dtype , endpoint = endpoint , ** kwargs )
120120
121121
122122def  ones (
@@ -127,8 +127,8 @@ def ones(
127127    device : Device  |  None  =  None ,
128128    ** kwargs : object ,
129129) ->  Array :
130-     _check_device (xp , device )
131-     return  xp .ones (shape , dtype = dtype , ** kwargs )
130+     with   _device_ctx (xp , device ): 
131+          return  xp .ones (shape , dtype = dtype , ** kwargs )
132132
133133
134134def  ones_like (
@@ -140,8 +140,8 @@ def ones_like(
140140    device : Device  |  None  =  None ,
141141    ** kwargs : object ,
142142) ->  Array :
143-     _check_device (xp , device ) 
144-     return  xp .ones_like (x , dtype = dtype , ** kwargs )
143+     with   _device_ctx (xp , device ,  like = x ): 
144+          return  xp .ones_like (x , dtype = dtype , ** kwargs )
145145
146146
147147def  zeros (
@@ -152,8 +152,8 @@ def zeros(
152152    device : Device  |  None  =  None ,
153153    ** kwargs : object ,
154154) ->  Array :
155-     _check_device (xp , device )
156-     return  xp .zeros (shape , dtype = dtype , ** kwargs )
155+     with   _device_ctx (xp , device ): 
156+          return  xp .zeros (shape , dtype = dtype , ** kwargs )
157157
158158
159159def  zeros_like (
@@ -165,8 +165,8 @@ def zeros_like(
165165    device : Device  |  None  =  None ,
166166    ** kwargs : object ,
167167) ->  Array :
168-     _check_device (xp , device ) 
169-     return  xp .zeros_like (x , dtype = dtype , ** kwargs )
168+     with   _device_ctx (xp , device ,  like = x ): 
169+          return  xp .zeros_like (x , dtype = dtype , ** kwargs )
170170
171171
172172# np.unique() is split into four functions in the array API: 
0 commit comments