Spaces:
Paused
Paused
| from typing import Any, Callable, Generator, List | |
| import pytest | |
| from .._events import ( | |
| ConnectionClosed, | |
| Data, | |
| EndOfMessage, | |
| Event, | |
| InformationalResponse, | |
| Request, | |
| Response, | |
| ) | |
| from .._headers import Headers, normalize_and_validate | |
| from .._readers import ( | |
| _obsolete_line_fold, | |
| ChunkedReader, | |
| ContentLengthReader, | |
| Http10Reader, | |
| READERS, | |
| ) | |
| from .._receivebuffer import ReceiveBuffer | |
| from .._state import ( | |
| CLIENT, | |
| CLOSED, | |
| DONE, | |
| IDLE, | |
| MIGHT_SWITCH_PROTOCOL, | |
| MUST_CLOSE, | |
| SEND_BODY, | |
| SEND_RESPONSE, | |
| SERVER, | |
| SWITCHED_PROTOCOL, | |
| ) | |
| from .._util import LocalProtocolError | |
| from .._writers import ( | |
| ChunkedWriter, | |
| ContentLengthWriter, | |
| Http10Writer, | |
| write_any_response, | |
| write_headers, | |
| write_request, | |
| WRITERS, | |
| ) | |
| from .helpers import normalize_data_events | |
# Round-trip fixtures shared by test_writers_simple and test_readers_simple.
# Each entry is ((role, state), event, wire_bytes): the writer registered
# under (role, state) is expected to serialize `event` to `wire_bytes`, and
# the reader registered under the same key is expected to parse `wire_bytes`
# back into an equal `event`.
SIMPLE_CASES = [
    # Request with headers
    (
        (CLIENT, IDLE),
        Request(
            target="/a",
            method="GET",
            headers=[("Host", "foo"), ("Connection", "close")],
        ),
        b"GET /a HTTP/1.1\r\n" b"Host: foo\r\n" b"Connection: close\r\n" b"\r\n",
    ),
    # Final response with one header
    (
        (SERVER, SEND_RESPONSE),
        Response(
            status_code=200,
            reason=b"OK",
            headers=[("Connection", "close")],
        ),
        b"HTTP/1.1 200 OK\r\n" b"Connection: close\r\n" b"\r\n",
    ),
    # Final response with no headers at all
    (
        (SERVER, SEND_RESPONSE),
        Response(
            status_code=200,
            reason=b"OK",
            headers=[],  # type: ignore[arg-type]
        ),
        b"HTTP/1.1 200 OK\r\n" b"\r\n",
    ),
    # Informational (1xx) response with one header
    (
        (SERVER, SEND_RESPONSE),
        InformationalResponse(
            status_code=101,
            reason=b"Upgrade",
            headers=[("Upgrade", "websocket")],
        ),
        b"HTTP/1.1 101 Upgrade\r\n" b"Upgrade: websocket\r\n" b"\r\n",
    ),
    # Informational (1xx) response with no headers at all
    (
        (SERVER, SEND_RESPONSE),
        InformationalResponse(
            status_code=101,
            reason=b"Upgrade",
            headers=[],  # type: ignore[arg-type]
        ),
        b"HTTP/1.1 101 Upgrade\r\n" b"\r\n",
    ),
]
def dowrite(writer: Callable[..., None], obj: Any) -> bytes:
    """Invoke ``writer(obj, sink)`` and return everything it emitted.

    ``writer`` is handed a one-argument callback; every chunk passed to that
    callback is collected, and the chunks are concatenated in call order.
    """
    pieces: List[bytes] = []
    emit = pieces.append
    writer(obj, emit)
    return b"".join(pieces)
def tw(writer: Any, obj: Any, expected: Any) -> None:
    """Assert that ``writer`` serializes ``obj`` to exactly ``expected``."""
    assert dowrite(writer, obj) == expected
def makebuf(data: bytes) -> ReceiveBuffer:
    """Return a new ReceiveBuffer that has had ``data`` appended to it."""
    filled = ReceiveBuffer()
    filled += data
    return filled
def tr(reader: Any, data: bytes, expected: Any) -> None:
    """Check that ``reader`` parses ``data`` into ``expected`` three ways.

    The reader is run against (1) a buffer holding exactly ``data``, (2) a
    buffer that grows one byte at a time, and (3) a buffer holding ``data``
    followed by extra bytes that must be left unconsumed.
    """

    def verify(result: Any) -> None:
        assert result == expected
        # Headers should always be returned as bytes, not e.g. bytearray
        # https://github.com/python-hyper/wsproto/pull/54#issuecomment-377709478
        header_pairs = getattr(result, "headers", [])
        for name, value in header_pairs:
            assert type(name) is bytes
            assert type(value) is bytes

    # 1. Everything available up front: the reader must drain the buffer.
    whole = makebuf(data)
    verify(reader(whole))
    assert not whole

    # 2. One byte at a time: the reader must return None until the final
    #    byte has arrived.
    growing = ReceiveBuffer()
    for one_byte in (data[pos : pos + 1] for pos in range(len(data))):
        assert reader(growing) is None
        growing += one_byte
    verify(reader(growing))

    # 3. Extra bytes after the message must stay in the buffer.
    padded = makebuf(data)
    padded += b"trailing"
    verify(reader(padded))
    assert bytes(padded) == b"trailing"
def test_writers_simple() -> None:
    """Every SIMPLE_CASES event serializes to its expected wire bytes."""
    for key, event, wire in SIMPLE_CASES:
        role, state = key
        writer = WRITERS[role, state]
        tw(writer, event, wire)
def test_readers_simple() -> None:
    """Every SIMPLE_CASES wire encoding parses back into its event."""
    for key, event, wire in SIMPLE_CASES:
        role, state = key
        reader = READERS[role, state]
        tr(reader, wire, event)
def test_writers_unusual() -> None:
    """Cover write_headers directly and the refusal to emit HTTP/1.0."""
    # Simple test of the write_headers utility routine
    two_headers = normalize_and_validate([("foo", "bar"), ("baz", "quux")])
    tw(write_headers, two_headers, b"foo: bar\r\nbaz: quux\r\n\r\n")

    no_headers = Headers([])
    tw(write_headers, no_headers, b"\r\n")

    # We understand HTTP/1.0, but we don't speak it
    with pytest.raises(LocalProtocolError):
        http10_request = Request(
            method="GET",
            target="/",
            headers=[("Host", "foo"), ("Connection", "close")],
            http_version="1.0",
        )
        tw(write_request, http10_request, None)

    with pytest.raises(LocalProtocolError):
        http10_response = Response(
            status_code=200,
            headers=[("Connection", "close")],
            http_version="1.0",
        )
        tw(write_any_response, http10_response, None)
def test_readers_unusual() -> None:
    """Exercise the message-head readers on unusual or malformed input.

    Covers HTTP/1.0 heads, odd header-value whitespace, a status line with no
    reason phrase, bare-LF line endings, obsolete line folding, and header
    lines that must be rejected with LocalProtocolError.
    """
    # Reading HTTP/1.0
    tr(
        READERS[CLIENT, IDLE],
        b"HEAD /foo HTTP/1.0\r\nSome: header\r\n\r\n",
        Request(
            method="HEAD",
            target="/foo",
            headers=[("Some", "header")],
            http_version="1.0",
        ),
    )

    # check no-headers, since it's only legal with HTTP/1.0
    tr(
        READERS[CLIENT, IDLE],
        b"HEAD /foo HTTP/1.0\r\n\r\n",
        Request(method="HEAD", target="/foo", headers=[], http_version="1.0"),  # type: ignore[arg-type]
    )

    tr(
        READERS[SERVER, SEND_RESPONSE],
        b"HTTP/1.0 200 OK\r\nSome: header\r\n\r\n",
        Response(
            status_code=200,
            headers=[("Some", "header")],
            http_version="1.0",
            reason=b"OK",
        ),
    )

    # single-character header values (actually disallowed by the ABNF in RFC
    # 7230 -- this is a bug in the standard that we originally copied...)
    tr(
        READERS[SERVER, SEND_RESPONSE],
        b"HTTP/1.0 200 OK\r\n" b"Foo: a a a a a \r\n\r\n",
        Response(
            status_code=200,
            headers=[("Foo", "a a a a a")],
            http_version="1.0",
            reason=b"OK",
        ),
    )

    # Empty headers -- also legal
    tr(
        READERS[SERVER, SEND_RESPONSE],
        b"HTTP/1.0 200 OK\r\n" b"Foo:\r\n\r\n",
        Response(
            status_code=200, headers=[("Foo", "")], http_version="1.0", reason=b"OK"
        ),
    )

    # A value made only of spaces and tabs also parses as empty
    tr(
        READERS[SERVER, SEND_RESPONSE],
        b"HTTP/1.0 200 OK\r\n" b"Foo: \t \t \r\n\r\n",
        Response(
            status_code=200, headers=[("Foo", "")], http_version="1.0", reason=b"OK"
        ),
    )

    # Tolerate broken servers that leave off the reason phrase
    tr(
        READERS[SERVER, SEND_RESPONSE],
        b"HTTP/1.0 200\r\n" b"Foo: bar\r\n\r\n",
        Response(
            status_code=200, headers=[("Foo", "bar")], http_version="1.0", reason=b""
        ),
    )

    # Tolerate headers line endings (\r\n and \n)
    # \n\r\n between headers and body
    tr(
        READERS[SERVER, SEND_RESPONSE],
        b"HTTP/1.1 200 OK\r\nSomeHeader: val\n\r\n",
        Response(
            status_code=200,
            headers=[("SomeHeader", "val")],
            http_version="1.1",
            reason=b"OK",
        ),
    )

    # delimited only with \n
    tr(
        READERS[SERVER, SEND_RESPONSE],
        b"HTTP/1.1 200 OK\nSomeHeader1: val1\nSomeHeader2: val2\n\n",
        Response(
            status_code=200,
            headers=[("SomeHeader1", "val1"), ("SomeHeader2", "val2")],
            http_version="1.1",
            reason=b"OK",
        ),
    )

    # mixed \r\n and \n
    tr(
        READERS[SERVER, SEND_RESPONSE],
        b"HTTP/1.1 200 OK\r\nSomeHeader1: val1\nSomeHeader2: val2\n\r\n",
        Response(
            status_code=200,
            headers=[("SomeHeader1", "val1"), ("SomeHeader2", "val2")],
            http_version="1.1",
            reason=b"OK",
        ),
    )

    # obsolete line folding
    tr(
        READERS[CLIENT, IDLE],
        b"HEAD /foo HTTP/1.1\r\n"
        b"Host: example.com\r\n"
        b"Some: multi-line\r\n"
        b" header\r\n"
        b"\tnonsense\r\n"
        b" \t \t\tI guess\r\n"
        b"Connection: close\r\n"
        b"More-nonsense: in the\r\n"
        b" last header \r\n\r\n",
        Request(
            method="HEAD",
            target="/foo",
            headers=[
                ("Host", "example.com"),
                ("Some", "multi-line header nonsense I guess"),
                ("Connection", "close"),
                ("More-nonsense", "in the last header"),
            ],
        ),
    )

    # Malformed header lines that must be rejected: a continuation line with
    # nothing to continue, whitespace (space or tab) between the field name
    # and the colon, and an empty field name.
    malformed_header_lines = (
        b" folded: line",
        b"foo : line",
        b"foo\t: line",
        b": line",
    )
    for header_line in malformed_header_lines:
        with pytest.raises(LocalProtocolError):
            tr(
                READERS[CLIENT, IDLE],
                b"HEAD /foo HTTP/1.1\r\n" + header_line + b"\r\n\r\n",
                None,
            )
def test__obsolete_line_fold_bytes() -> None:
    """Feed plain ``bytes`` lines through _obsolete_line_fold."""
    # _obsolete_line_fold has a defensive cast to bytearray, which is
    # necessary to protect against O(n^2) behavior in case anyone ever passes
    # in regular bytestrings... but right now we never pass in regular
    # bytestrings. so this test just exists to get some coverage on that
    # defensive cast.
    raw_lines = [b"aaa", b"bbb", b" ccc", b"ddd"]
    folded = list(_obsolete_line_fold(raw_lines))
    assert folded == [b"aaa", bytearray(b"bbb ccc"), b"ddd"]
| def _run_reader_iter( | |
| reader: Any, buf: bytes, do_eof: bool | |
| ) -> Generator[Any, None, None]: | |
| while True: | |
| event = reader(buf) | |
| if event is None: | |
| break | |
| yield event | |
| # body readers have undefined behavior after returning EndOfMessage, | |
| # because this changes the state so they don't get called again | |
| if type(event) is EndOfMessage: | |
| break | |
| if do_eof: | |
| assert not buf | |
| yield reader.read_eof() | |
def _run_reader(*args: Any) -> List[Event]:
    """Collect the events from _run_reader_iter and normalize the Data events."""
    return normalize_data_events(list(_run_reader_iter(*args)))
def t_body_reader(thunk: Any, data: bytes, expected: Any, do_eof: bool = False) -> None:
    """Check that a body reader turns ``data`` into the ``expected`` events.

    ``thunk`` is called to build a fresh reader for each scenario: all data
    available at once, data arriving one byte at a time, and (only when the
    expected events include an EndOfMessage and ``do_eof`` is false) data
    followed by extra trailing bytes.
    """
    # Simple: consume whole thing
    print("Test 1")
    whole = makebuf(data)
    assert _run_reader(thunk(), whole, do_eof) == expected

    # Incrementally growing buffer
    print("Test 2")
    reader = thunk()
    growing = ReceiveBuffer()
    collected = []
    for pos in range(len(data)):
        collected.extend(_run_reader(reader, growing, False))
        growing += data[pos : pos + 1]
    collected.extend(_run_reader(reader, growing, do_eof))
    assert normalize_data_events(collected) == expected

    # Trailing data only makes sense for a message that ends on its own,
    # rather than one that is ended by EOF.
    saw_end = any(type(item) is EndOfMessage for item in expected)
    if do_eof or not saw_end:
        return
    padded = makebuf(data + b"trailing")
    assert _run_reader(thunk(), padded, False) == expected
def test_ContentLengthReader() -> None:
    """ContentLengthReader handles an empty body and an exact-length body."""
    # Zero-length body: the message ends immediately.
    t_body_reader(lambda: ContentLengthReader(0), b"", [EndOfMessage()])

    # Ten-byte body delivered in full.
    body = b"0123456789"
    t_body_reader(
        lambda: ContentLengthReader(10),
        body,
        [Data(data=body), EndOfMessage()],
    )
def test_Http10Reader() -> None:
    """Http10Reader emits EndOfMessage only when EOF is signalled."""
    # Empty body, terminated by EOF.
    t_body_reader(Http10Reader, b"", [EndOfMessage()], do_eof=True)

    # Without EOF, the data is delivered but the message stays open.
    payload = b"asdf"
    t_body_reader(Http10Reader, payload, [Data(data=payload)], do_eof=False)

    # With EOF, the same data is followed by EndOfMessage.
    t_body_reader(
        Http10Reader,
        payload,
        [Data(data=payload), EndOfMessage()],
        do_eof=True,
    )
def test_ChunkedReader() -> None:
    """Run ChunkedReader over well-formed and malformed chunked bodies."""
    last_chunk = b"0\r\n\r\n"
    two_chunks = b"5\r\n01234\r\n" + b"10\r\n0123456789abcdef\r\n"
    two_chunks_payload = b"012340123456789abcdef"

    # Empty body, without and with a trailer
    t_body_reader(ChunkedReader, last_chunk, [EndOfMessage()])
    t_body_reader(
        ChunkedReader,
        b"0\r\nSome: header\r\n\r\n",
        [EndOfMessage(headers=[("Some", "header")])],
    )

    # Two chunks, with and without a trailer
    t_body_reader(
        ChunkedReader,
        two_chunks + b"0\r\n" + b"Some: header\r\n\r\n",
        [
            Data(data=two_chunks_payload),
            EndOfMessage(headers=[("Some", "header")]),
        ],
    )
    t_body_reader(
        ChunkedReader,
        two_chunks + last_chunk,
        [Data(data=two_chunks_payload), EndOfMessage()],
    )

    # handles upper and lowercase hex
    filler = b"x" * 0xAA
    t_body_reader(
        ChunkedReader,
        b"aA\r\n" + filler + b"\r\n" + last_chunk,
        [Data(data=filler), EndOfMessage()],
    )

    # refuses arbitrarily long chunk integers
    with pytest.raises(LocalProtocolError):
        # Technically this is legal HTTP/1.1, but we refuse to process chunk
        # sizes that don't fit into 20 characters of hex
        t_body_reader(ChunkedReader, b"9" * 100 + b"\r\nxxx", [Data(data=b"xxx")])

    # refuses garbage in the chunk count
    with pytest.raises(LocalProtocolError):
        t_body_reader(ChunkedReader, b"10\x00\r\nxxx", None)

    # handles (and discards) "chunk extensions" omg wtf
    t_body_reader(
        ChunkedReader,
        b"5; hello=there\r\n"
        + b"xxxxx"
        + b"\r\n"
        + b'0; random="junk"; some=more; canbe=lonnnnngg\r\n\r\n',
        [Data(data=b"xxxxx"), EndOfMessage()],
    )

    # whitespace after the chunk size
    t_body_reader(
        ChunkedReader,
        b"5 \r\n01234\r\n" + last_chunk,
        [Data(data=b"01234"), EndOfMessage()],
    )
def test_ContentLengthWriter() -> None:
    """ContentLengthWriter passes data through and enforces the declared length.

    Each scenario uses a fresh writer declared with a length of 5.
    """
    # Exactly five bytes split over two Data events, then a clean end.
    w = ContentLengthWriter(5)
    assert dowrite(w, Data(data=b"123")) == b"123"
    assert dowrite(w, Data(data=b"45")) == b"45"
    assert dowrite(w, EndOfMessage()) == b""

    # A single Data event that is too long.
    w = ContentLengthWriter(5)
    with pytest.raises(LocalProtocolError):
        dowrite(w, Data(data=b"123456"))

    # A second Data event that pushes the total past the declared length.
    w = ContentLengthWriter(5)
    dowrite(w, Data(data=b"123"))
    with pytest.raises(LocalProtocolError):
        dowrite(w, Data(data=b"456"))

    # Ending the message before the declared length has been written.
    w = ContentLengthWriter(5)
    dowrite(w, Data(data=b"123"))
    with pytest.raises(LocalProtocolError):
        dowrite(w, EndOfMessage())

    # An EndOfMessage carrying headers is rejected even after a full body.
    # These two writes were previously bare `==` expressions whose results
    # were discarded; they are now asserted.
    w = ContentLengthWriter(5)
    assert dowrite(w, Data(data=b"123")) == b"123"
    assert dowrite(w, Data(data=b"45")) == b"45"
    with pytest.raises(LocalProtocolError):
        dowrite(w, EndOfMessage(headers=[("Etag", "asdf")]))
def test_ChunkedWriter() -> None:
    """ChunkedWriter frames Data as chunks and EndOfMessage as the last chunk."""
    writer = ChunkedWriter()

    # Chunk sizes are written in hex: 3 -> "3", 20 -> "14".
    assert dowrite(writer, Data(data=b"aaa")) == b"3\r\naaa\r\n"
    twenty = b"a" * 20
    assert dowrite(writer, Data(data=twenty)) == b"14\r\n" + twenty + b"\r\n"

    # Empty Data produces no output at all.
    assert dowrite(writer, Data(data=b"")) == b""

    # End of message, without and then with trailing headers.
    assert dowrite(writer, EndOfMessage()) == b"0\r\n\r\n"
    with_trailers = EndOfMessage(headers=[("Etag", "asdf"), ("a", "b")])
    assert dowrite(writer, with_trailers) == b"0\r\nEtag: asdf\r\na: b\r\n\r\n"
def test_Http10Writer() -> None:
    """Http10Writer passes data through untouched and rejects trailers."""
    writer = Http10Writer()

    body = b"1234"
    assert dowrite(writer, Data(data=body)) == body
    assert dowrite(writer, EndOfMessage()) == b""

    # An EndOfMessage carrying headers must be refused.
    with pytest.raises(LocalProtocolError):
        dowrite(writer, EndOfMessage(headers=[("Etag", "asdf")]))
def test_reject_garbage_after_request_line() -> None:
    """A request line with extra text after the HTTP version is rejected."""
    with pytest.raises(LocalProtocolError):
        tr(
            READERS[CLIENT, IDLE],
            b"HEAD /foo HTTP/1.1 xxxxxx\r\n" b"Host: a\r\n\r\n",
            None,
        )


def test_reject_garbage_after_response_line() -> None:
    """A status line with a NUL byte and junk after the reason is rejected."""
    with pytest.raises(LocalProtocolError):
        tr(READERS[SERVER, SEND_RESPONSE], b"HTTP/1.0 200 OK\x00xxxx\r\n\r\n", None)
def test_reject_garbage_in_header_line() -> None:
    """A NUL byte inside a header value is rejected."""
    raw_head = b"HEAD /foo HTTP/1.1\r\n" b"Host: foo\x00bar\r\n\r\n"
    with pytest.raises(LocalProtocolError):
        tr(READERS[CLIENT, IDLE], raw_head, None)
def test_reject_non_vchar_in_path() -> None:
    """A request target containing NUL, space, DEL or 0xEE is rejected."""
    prefix = b"HEAD /"
    suffix = b" HTTP/1.1\r\nHost: foobar\r\n\r\n"
    for bad_char in b"\x00\x20\x7f\xee":
        # Iterating bytes yields ints, so wrap the value back into bytes.
        message = bytearray(prefix + bytes([bad_char]) + suffix)
        with pytest.raises(LocalProtocolError):
            tr(READERS[CLIENT, IDLE], message, None)
# https://github.com/python-hyper/h11/issues/57
def test_allow_some_garbage_in_cookies() -> None:
    """A \\x01 control byte inside a Set-Cookie value is accepted as-is."""
    cookie = "___utmvafIumyLc=kUd\x01UpAt; path=/; Max-Age=900"
    raw_head = (
        b"HEAD /foo HTTP/1.1\r\n"
        b"Host: foo\r\n"
        b"Set-Cookie: " + cookie.encode("ascii") + b"\r\n"
        b"\r\n"
    )
    parsed = Request(
        method="HEAD",
        target="/foo",
        headers=[
            ("Host", "foo"),
            ("Set-Cookie", cookie),
        ],
    )
    tr(READERS[CLIENT, IDLE], raw_head, parsed)
def test_host_comes_first() -> None:
    """write_headers emits Host first even when it was supplied last."""
    headers = normalize_and_validate([("foo", "bar"), ("Host", "example.com")])
    tw(write_headers, headers, b"Host: example.com\r\nfoo: bar\r\n\r\n")