diff --git a/src/mcp/client/auth/utils.py b/src/mcp/client/auth/utils.py index d6b05e066..39e6bcd94 100644 --- a/src/mcp/client/auth/utils.py +++ b/src/mcp/client/auth/utils.py @@ -1,4 +1,3 @@ -import re from urllib.parse import urljoin, urlparse from httpx import Request, Response @@ -16,6 +15,45 @@ from mcp.shared.inbound import MCP_PROTOCOL_VERSION_HEADER +def _iter_www_auth_params(www_auth_header: str) -> list[str]: + """Split a WWW-Authenticate challenge into auth-param tokens.""" + params_start = www_auth_header.find(" ") + if params_start == -1: + return [] + + params: list[str] = [] + current: list[str] = [] + in_quotes = False + escape_next = False + + for char in www_auth_header[params_start + 1 :]: + if escape_next: + current.append(char) + escape_next = False + continue + if char == "\\" and in_quotes: + current.append(char) + escape_next = True + continue + if char == '"': + in_quotes = not in_quotes + current.append(char) + continue + if char == "," and not in_quotes: + param = "".join(current).strip() + if param: + params.append(param) + current = [] + continue + current.append(char) + + param = "".join(current).strip() + if param: + params.append(param) + + return params + + def extract_field_from_www_auth(response: Response, field_name: str) -> str | None: """Extract field from WWW-Authenticate header. @@ -26,13 +64,16 @@ def extract_field_from_www_auth(response: Response, field_name: str) -> str | No if not www_auth_header: return None - # Pattern matches: field_name="value" or field_name=value (unquoted) - pattern = rf'{field_name}=(?:"([^"]+)"|([^\s,]+))' - match = re.search(pattern, www_auth_header) + for param in _iter_www_auth_params(www_auth_header): + name, separator, value = param.partition("=") + if separator != "=" or name.strip() != field_name: + continue - if match: - # Return quoted value if present, otherwise unquoted value - return match.group(1) or match.group(2) + value = value.strip() + if len(value) >= 2 and value[0] == value[-1] == '"': + value = value[1:-1] + if value: + return value return None diff --git a/tests/client/test_auth.py b/tests/client/test_auth.py index 1ec38ccf6..25fda1de9 100644 --- a/tests/client/test_auth.py +++ b/tests/client/test_auth.py @@ -2047,6 +2047,7 @@ def test_extract_field_from_www_auth_valid_cases( # Header without requested field ('Bearer realm="api", error="insufficient_scope"', "scope", "no scope parameter"), ('Bearer realm="api", scope="read write"', "resource_metadata", "no resource_metadata parameter"), + ("Bearer", "scope", "no auth parameters"), # Malformed field (empty value) ("Bearer scope=", "scope", "malformed scope parameter"), ("Bearer resource_metadata=", "resource_metadata", "malformed resource_metadata parameter"), @@ -2070,6 +2071,153 @@ def test_extract_field_from_www_auth_invalid_cases( result = extract_field_from_www_auth(init_response, field_name) assert result is None, f"Should return None for {description}" + def test_extract_field_from_www_auth_does_not_match_substring_param_name( + self, + client_metadata: OAuthClientMetadata, + mock_storage: MockTokenStorage, + ): + """Test auth-param names are matched exactly, not as substrings.""" + + init_response = httpx.Response( + status_code=401, + headers={"WWW-Authenticate": 'Bearer error_scope="decoy", scope="read write"'}, + request=httpx.Request("GET", "https://api.example.com/test"), + ) + + result = extract_field_from_www_auth(init_response, "scope") + assert result == "read write" + + def test_extract_field_from_www_auth_ignores_prefixed_param_only( + self, + client_metadata: OAuthClientMetadata, + mock_storage: MockTokenStorage, + ): + """Test a prefixed auth-param does not satisfy the requested field.""" + + init_response = httpx.Response( + status_code=401, + headers={"WWW-Authenticate": 'Bearer custom_scope="leaked"'}, + request=httpx.Request("GET", "https://api.example.com/test"), + ) + + result = extract_field_from_www_auth(init_response, "scope") + assert result is None + + def test_extract_resource_metadata_from_www_auth_ignores_prefixed_param( + self, + client_metadata: OAuthClientMetadata, + mock_storage: MockTokenStorage, + ): + """Test resource_metadata does not match inside another auth-param name.""" + + init_response = httpx.Response( + status_code=401, + headers={"WWW-Authenticate": 'Bearer x_resource_metadata="https://decoy.example.com"'}, + request=httpx.Request("GET", "https://api.example.com/test"), + ) + + result = extract_resource_metadata_from_www_auth(init_response) + assert result is None + + def test_extract_field_from_www_auth_ignores_param_like_text_inside_quoted_value( + self, + client_metadata: OAuthClientMetadata, + mock_storage: MockTokenStorage, + ): + """Test quoted values cannot shadow a later auth-param with the same name.""" + + init_response = httpx.Response( + status_code=401, + headers={"WWW-Authenticate": 'Bearer realm="api, scope=decoy", scope="read write"'}, + request=httpx.Request("GET", "https://api.example.com/test"), + ) + + result = extract_field_from_www_auth(init_response, "scope") + assert result == "read write" + + def test_extract_field_from_www_auth_ignores_quoted_value_when_only_decoy_exists( + self, + client_metadata: OAuthClientMetadata, + mock_storage: MockTokenStorage, + ): + """Test a field-like string inside a quoted value is not an auth-param.""" + + init_response = httpx.Response( + status_code=401, + headers={"WWW-Authenticate": 'Bearer realm="api scope=leaked"'}, + request=httpx.Request("GET", "https://api.example.com/test"), + ) + + result = extract_field_from_www_auth(init_response, "scope") + assert result is None + + def test_extract_field_from_www_auth_handles_escaped_quote_inside_quoted_value( + self, + client_metadata: OAuthClientMetadata, + mock_storage: MockTokenStorage, + ): + """Test escaped characters inside a quoted value do not break splitting.""" + + init_response = httpx.Response( + status_code=401, + headers={"WWW-Authenticate": 'Bearer realm="api \\"scope\\", still realm", scope="read write"'}, + request=httpx.Request("GET", "https://api.example.com/test"), + ) + + result = extract_field_from_www_auth(init_response, "scope") + assert result == "read write" + + def test_extract_field_from_www_auth_ignores_empty_comma_segments( + self, + client_metadata: OAuthClientMetadata, + mock_storage: MockTokenStorage, + ): + """Test empty segments between commas are ignored while parsing.""" + + init_response = httpx.Response( + status_code=401, + headers={"WWW-Authenticate": 'Bearer scope="read write", , error="insufficient_scope"'}, + request=httpx.Request("GET", "https://api.example.com/test"), + ) + + result = extract_field_from_www_auth(init_response, "scope") + assert result == "read write" + + def test_extract_field_from_www_auth_ignores_trailing_comma( + self, + client_metadata: OAuthClientMetadata, + mock_storage: MockTokenStorage, + ): + """Test a trailing comma does not create a malformed final param.""" + + init_response = httpx.Response( + status_code=401, + headers={"WWW-Authenticate": 'Bearer scope="read write",'}, + request=httpx.Request("GET", "https://api.example.com/test"), + ) + + result = extract_field_from_www_auth(init_response, "scope") + assert result == "read write" + + def test_extract_resource_metadata_from_www_auth_ignores_quoted_value_decoy( + self, + client_metadata: OAuthClientMetadata, + mock_storage: MockTokenStorage, + ): + """Test resource_metadata is not extracted from another quoted param value.""" + + init_response = httpx.Response( + status_code=401, + headers={ + "WWW-Authenticate": 'Bearer realm="api, resource_metadata=https://decoy.example.com", ' + 'resource_metadata="https://api.example.com/.well-known/oauth-protected-resource"' + }, + request=httpx.Request("GET", "https://api.example.com/test"), + ) + + result = extract_resource_metadata_from_www_auth(init_response) + assert result == "https://api.example.com/.well-known/oauth-protected-resource" + class TestCIMD: """Test Client ID Metadata Document (CIMD) support."""