Skip to content

Commit

Permalink
Redesign extensions
Browse files Browse the repository at this point in the history
  • Loading branch information
User committed Jul 1, 2023
1 parent f92c5da commit 3a75a7f
Show file tree
Hide file tree
Showing 5 changed files with 138 additions and 43 deletions.
2 changes: 1 addition & 1 deletion Makefile
Expand Up @@ -27,7 +27,7 @@ pytest:
pytest -n30 -x --cov grab --cov-report term-missing

test: check pytest
tox -e py38-check
tox -e python38-check

#release:
# git push \
Expand Down
67 changes: 50 additions & 17 deletions grab/base.py
Expand Up @@ -2,7 +2,7 @@

import typing
from abc import ABCMeta, abstractmethod
from collections.abc import Callable, Generator, Mapping, MutableMapping
from collections.abc import Callable, Generator, Mapping, MutableMapping, Sequence
from contextlib import contextmanager
from copy import deepcopy
from typing import Any, Generic, TypeVar, cast
Expand Down Expand Up @@ -44,27 +44,54 @@ class BaseResponse:

class BaseExtension(Generic[RequestT, ResponseT], metaclass=ABCMeta):
ext_handlers: Mapping[str, Callable[..., Any]] = {}
__slots__ = ()

def __set_name__(self, owner: BaseClient[RequestT, ResponseT], attr: str) -> None:
owner.extensions[attr] = {
"instance": self,
}
for point_name, func in self.ext_handlers.items():
owner.ext_handlers[point_name].append(func)
registry: MutableMapping[
str,
tuple[
type[BaseClient[RequestT, ResponseT]], BaseExtension[RequestT, ResponseT]
],
] = {}
__slots__ = ["owners"]

def __set_name__(
self, owner: type[BaseClient[RequestT, ResponseT]], attr: str
) -> None:
self.registry[attr] = (owner, self)

@abstractmethod
def reset(self) -> None:
...

@classmethod
def get_extensions(
cls, obj: BaseClient[RequestT, ResponseT]
) -> Sequence[tuple[str, BaseExtension[RequestT, ResponseT]]]:
owner_reg: MutableMapping[
type[BaseClient[RequestT, ResponseT]],
list[tuple[str, BaseExtension[RequestT, ResponseT]]],
] = {}
for ext_key, ext_tuple in cls.registry.items():
owner_type, ext = ext_tuple
owner_reg.setdefault(owner_type, []).append((ext_key, ext))
ret = []
stack = [obj.__class__]
while stack:
ptr = stack.pop()
if ptr in owner_reg:
ext_list = owner_reg[ptr]
ret.extend(ext_list)
for base in ptr.__bases__:
if base != object().__class__:
stack.append(base)
return ret


class Retry:
def __init__(self) -> None:
self.state: MutableMapping[str, int] = {}


class BaseClient(Generic[RequestT, ResponseT], metaclass=ABCMeta):
__slots__ = ["transport"]
__slots__ = ["transport", "ext_handlers"]
transport: BaseTransport[RequestT, ResponseT]

@property
Expand All @@ -78,20 +105,26 @@ def default_transport_class(self) -> type[BaseTransport[RequestT, ResponseT]]:
...

extensions: MutableMapping[str, MutableMapping[str, Any]] = {}
ext_handlers: Mapping[str, list[Callable[..., Any]]] = {
"request:pre": [],
"request_cookies": [],
"response:post": [],
"init-retry": [],
"retry": [],
}
ext_handlers: Mapping[str, list[Callable[..., Any]]]

def __init__(
self,
transport: None
| BaseTransport[RequestT, ResponseT]
| type[BaseTransport[RequestT, ResponseT]] = None,
):
self.ext_handlers = {
"request:pre": [],
"request_cookies": [],
"response:post": [],
"init-retry": [],
"retry": [],
}
for ext_key, _ext_proxy in BaseExtension.get_extensions(self):
ext = getattr(self, ext_key)
print(self, ext_key, _ext_proxy, ext)
for point_name, func in ext.ext_handlers.items():
self.ext_handlers[point_name].append(func)
self.transport = self.default_transport_class.resolve_entity(
transport, self.default_transport_class
)
Expand Down
75 changes: 54 additions & 21 deletions grab/extensions.py
@@ -1,25 +1,32 @@
from __future__ import annotations

import weakref
from collections.abc import Mapping, MutableMapping
from http.cookiejar import Cookie, CookieJar
from typing import Any, cast
from urllib.parse import urljoin, urlsplit

from .base import BaseExtension
from .base import BaseClient, BaseExtension
from .document import Document
from .errors import GrabTooManyRedirectsError
from .request import HttpRequest
from .util.cookies import build_cookie_header, build_jar, create_cookie


class RedirectExtension(BaseExtension[HttpRequest, Document]):
def __init__(self, cookiejar: None | CookieJar = None) -> None:
self.cookiejar = cookiejar if cookiejar else CookieJar()
def __init__(self) -> None:
self.ext_handlers = {
"init-retry": self.process_init_retry,
"retry": self.process_retry,
}

def __get__(
self,
obj: BaseClient[HttpRequest, Document],
objtype: None | type[BaseClient[HttpRequest, Document]] = None,
) -> RedirectExtension:
return self

def find_redirect_url(self, doc: Document) -> None | str:
assert doc.headers is not None
if doc.code in {301, 302, 303, 307, 308} and doc.headers["Location"]:
Expand Down Expand Up @@ -48,22 +55,43 @@ def process_retry(
return None, None


class CookiesExtension(BaseExtension[HttpRequest, Document]):
class CookiesStore:
__slots__ = ("cookiejar", "ext_handlers")

def __init__(self, cookiejar: None | CookieJar = None) -> None:
self.cookiejar = cookiejar if cookiejar else CookieJar()
self.ext_handlers = {
"request:pre": self.process_request_pre,
"response:post": self.process_response_post,
}

def process_request_pre(self, req: HttpRequest) -> None:
self.update(req.cookies, req.url)
if hdr := build_cookie_header(self.cookiejar, req.url, req.headers):
if req.headers.get("Cookie"):
raise ValueError(
"Could not configure request with session cookies"
" because it has already Cookie header"
)
req.cookie_header = hdr

def process_response_post(
self, req: HttpRequest, doc: Document # pylint: disable=unused-argument
) -> None:
for item in doc.cookies:
self.cookiejar.set_cookie(item)

def reset(self) -> None:
self.clear()

def set_cookie(self, cookie: Cookie) -> None:
self.cookiejar.set_cookie(cookie)

def clear(self) -> None:
"""Clear all remembered cookies."""
self.cookiejar.clear()

def clone(self) -> CookiesExtension:
def clone(self) -> CookiesStore:
return self.__class__(build_jar(list(self.cookiejar)))

def update(self, cookies: Mapping[str, Any], request_url: str) -> None:
Expand All @@ -78,7 +106,8 @@ def update(self, cookies: Mapping[str, Any], request_url: str) -> None:

def __getstate__(self) -> MutableMapping[str, Any]:
state = {}
for name, value in self.__dict__.items():
for name in self.__slots__:
value = getattr(self, name)
if name == "cookiejar":
state["_cookiejar_items"] = list(value)
else:
Expand All @@ -92,21 +121,25 @@ def __setstate__(self, state: Mapping[str, Any]) -> None:
else:
setattr(self, name, value)

def process_request_pre(self, req: HttpRequest) -> None:
self.update(req.cookies, req.url)
if hdr := build_cookie_header(self.cookiejar, req.url, req.headers):
if req.headers.get("Cookie"):
raise ValueError(
"Could not configure request with session cookies"
" because it has already Cookie header"
)
req.cookie_header = hdr

def process_response_post(
self, req: HttpRequest, doc: Document # pylint: disable=unused-argument
) -> None:
for item in doc.cookies:
self.cookiejar.set_cookie(item)
class CookiesExtension(BaseExtension[HttpRequest, Document]):
__slots__ = []

owner_store_reg: MutableMapping[
BaseClient[HttpRequest, Document], CookiesStore
] = {}

def __init__(self) -> None:
self.owners: weakref.WeakKeyDictionary[
BaseClient[HttpRequest, Document], CookiesStore
] = weakref.WeakKeyDictionary()

def __get__(
self,
obj: BaseClient[HttpRequest, Document],
objtype: None | type[BaseClient[HttpRequest, Document]] = None,
) -> CookiesStore:
return self.owners.setdefault(obj, CookiesStore())

def reset(self) -> None:
self.clear()
pass
31 changes: 30 additions & 1 deletion tests/test_grab_cookies.py
Expand Up @@ -17,7 +17,7 @@ def test_parsing_response_cookies(self) -> None:

def test_multiple_cookies(self) -> None:
self.server.add_response(Response())
request(self.server.get_url(), cookies={"foo": "1", "bar": "2"})
request(self.server.get_url(), client=Grab, cookies={"foo": "1", "bar": "2"})
self.assertEqual(
{(x.key, x.value) for x in self.server.request.cookies.values()},
{("foo", "1"), ("bar", "2")},
Expand Down Expand Up @@ -199,3 +199,32 @@ def callback() -> bytes:
# request page one more time, sending cookie
# should not fail
request(self.server.get_url())

def test_different_instances(self) -> None:
grab1 = Grab()
self.server.add_response(Response(headers=[("Set-Cookie", "key1=val1")]))
doc1 = grab1.request(self.server.get_url())
self.assertTrue(
all(x.name == "key1" and x.value == "val1" for x in doc1.cookies)
)
self.assertTrue(
all(x.name == "key1" and x.value == "val1" for x in grab1.cookies.cookiejar)
)

grab2 = Grab()
self.server.add_response(Response(headers=[("Set-Cookie", "key2=val2")]))
doc2 = grab2.request(self.server.get_url())
self.assertTrue(
all(x.name == "key2" and x.value == "val2" for x in doc2.cookies)
)
self.assertTrue(
all(x.name == "key2" and x.value == "val2" for x in grab2.cookies.cookiejar)
)

# double check grab1
self.assertTrue(
all(x.name == "key1" and x.value == "val1" for x in doc1.cookies)
)
self.assertTrue(
all(x.name == "key1" and x.value == "val1" for x in grab1.cookies.cookiejar)
)
6 changes: 3 additions & 3 deletions tox.ini
Expand Up @@ -16,7 +16,7 @@ deps =
commands =
make pytest

[testenv:py38-test]
[testenv:python38-test]
commands =
make pytest
basepython=/opt/python38/bin/python3.8
Expand All @@ -30,7 +30,7 @@ commands =
make flake8
echo "OK"

[testenv:py38-check]
[testenv:python38-check]
commands =
python -V
make check
Expand All @@ -41,7 +41,7 @@ commands =
#echo "OK"
basepython=/opt/python38/bin/python3.8

[testenv:py38-mypy]
[testenv:python38-mypy]
commands =
python -V
make mypy
Expand Down

0 comments on commit 3a75a7f

Please sign in to comment.