Skip to content

Commit

Permalink
chore: better types for funcs (#743)
Browse files Browse the repository at this point in the history
  • Loading branch information
kumaraditya303 authored Jun 7, 2021
1 parent f720f19 commit 21b7d1c
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 13 deletions.
15 changes: 7 additions & 8 deletions playwright/_impl/_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,13 +151,13 @@ class Connection:
def __init__(
self,
dispatcher_fiber: Any,
object_factory: Callable[[ChannelOwner, str, str, Dict], Any],
object_factory: Callable[[ChannelOwner, str, str, Dict], ChannelOwner],
transport: Transport,
) -> None:
self._dispatcher_fiber = dispatcher_fiber
self._transport = transport
self._transport.on_message = lambda msg: self._dispatch(msg)
self._waiting_for_object: Dict[str, Any] = {}
self._waiting_for_object: Dict[str, Callable[[ChannelOwner], None]] = {}
self._last_id = 0
self._objects: Dict[str, ChannelOwner] = {}
self._callbacks: Dict[int, ProtocolCallback] = {}
Expand Down Expand Up @@ -189,19 +189,19 @@ def cleanup(self) -> None:
for ws_connection in self._child_ws_connections:
ws_connection._transport.dispose()

async def wait_for_object_with_known_name(self, guid: str) -> Any:
async def wait_for_object_with_known_name(self, guid: str) -> ChannelOwner:
if guid in self._objects:
return self._objects[guid]
callback = self._loop.create_future()
callback: asyncio.Future[ChannelOwner] = self._loop.create_future()

def callback_wrapper(result: Any) -> None:
def callback_wrapper(result: ChannelOwner) -> None:
callback.set_result(result)

self._waiting_for_object[guid] = callback_wrapper
return await callback

def call_on_object_with_known_name(
self, guid: str, callback: Callable[[Any], None]
self, guid: str, callback: Callable[[ChannelOwner], None]
) -> None:
self._waiting_for_object[guid] = callback

Expand Down Expand Up @@ -279,8 +279,7 @@ def _dispatch(self, msg: ParsedMessagePayload) -> None:

def _create_remote_object(
self, parent: ChannelOwner, type: str, guid: str, initializer: Dict
) -> Any:
result: ChannelOwner
) -> ChannelOwner:
initializer = self._replace_guids_with_channels(initializer)
result = self._object_factory(parent, type, guid, initializer)
if guid in self._waiting_for_object:
Expand Down
4 changes: 2 additions & 2 deletions playwright/_impl/_object_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Dict, cast
from typing import Dict, cast

from playwright._impl._artifact import Artifact
from playwright._impl._browser import Browser
Expand Down Expand Up @@ -41,7 +41,7 @@ def __init__(

def create_remote_object(
parent: ChannelOwner, type: str, guid: str, initializer: Dict
) -> Any:
) -> ChannelOwner:
if type == "Artifact":
return Artifact(parent, type, guid, initializer)
if type == "BindingCall":
Expand Down
7 changes: 4 additions & 3 deletions playwright/_impl/_transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,13 @@
import sys
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Any, Dict, Optional
from typing import Callable, Dict, Optional

import websockets
from pyee import AsyncIOEventEmitter

from playwright._impl._api_types import Error
from playwright._impl._helper import ParsedMessagePayload


# Sourced from: https://github.com/pytest-dev/pytest/blob/da01ee0a4bb0af780167ecd228ab3ad249511302/src/_pytest/faulthandler.py#L69-L77
Expand All @@ -44,7 +45,7 @@ def _get_stderr_fileno() -> Optional[int]:
class Transport(ABC):
def __init__(self, loop: asyncio.AbstractEventLoop) -> None:
self._loop = loop
self.on_message = lambda _: None
self.on_message: Callable[[ParsedMessagePayload], None] = lambda _: None
self.on_error_future: asyncio.Future = loop.create_future()

@abstractmethod
Expand Down Expand Up @@ -72,7 +73,7 @@ def serialize_message(self, message: Dict) -> bytes:
print("\x1b[32mSEND>\x1b[0m", json.dumps(message, indent=2))
return msg.encode()

def deserialize_message(self, data: bytes) -> Any:
def deserialize_message(self, data: bytes) -> ParsedMessagePayload:
obj = json.loads(data)

if "DEBUGP" in os.environ: # pragma: no cover
Expand Down

0 comments on commit 21b7d1c

Please sign in to comment.