Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions array_api_strict/_statistical_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,8 +153,8 @@ def std(
keepdims: bool = False,
) -> Array:
# Note: the keyword argument correction is different here
if x.dtype not in _real_floating_dtypes:
raise TypeError("Only real floating-point dtypes are allowed in std")
if x.dtype not in _floating_dtypes:

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Gate on the array API version?

raise TypeError("Only floating-point dtypes are allowed in std")
return Array._new(np.std(x._array, axis=axis, ddof=correction, keepdims=keepdims), device=x.device)


Expand Down Expand Up @@ -185,6 +185,6 @@ def var(
keepdims: bool = False,
) -> Array:
# Note: the keyword argument correction is different here
if x.dtype not in _real_floating_dtypes:
raise TypeError("Only real floating-point dtypes are allowed in var")
if x.dtype not in _floating_dtypes:
raise TypeError("Only floating-point dtypes are allowed in var")
return Array._new(np.var(x._array, axis=axis, ddof=correction, keepdims=keepdims), device=x.device)
Loading