diff --git a/array_api_strict/_statistical_functions.py b/array_api_strict/_statistical_functions.py index 35876dda..023bb871 100644 --- a/array_api_strict/_statistical_functions.py +++ b/array_api_strict/_statistical_functions.py @@ -153,8 +153,8 @@ def std( keepdims: bool = False, ) -> Array: # Note: the keyword argument correction is different here - if x.dtype not in _real_floating_dtypes: - raise TypeError("Only real floating-point dtypes are allowed in std") + if x.dtype not in _floating_dtypes: + raise TypeError("Only floating-point dtypes are allowed in std") return Array._new(np.std(x._array, axis=axis, ddof=correction, keepdims=keepdims), device=x.device) @@ -185,6 +185,6 @@ def var( keepdims: bool = False, ) -> Array: # Note: the keyword argument correction is different here - if x.dtype not in _real_floating_dtypes: - raise TypeError("Only real floating-point dtypes are allowed in var") + if x.dtype not in _floating_dtypes: + raise TypeError("Only floating-point dtypes are allowed in var") return Array._new(np.var(x._array, axis=axis, ddof=correction, keepdims=keepdims), device=x.device)