feat: more tests & update db_json.
andoludo committed Mar 10, 2024
1 parent 8ba2669 commit 6c33830
Showing 10 changed files with 676 additions and 76 deletions.
34 changes: 7 additions & 27 deletions bearish/scrapers/base.py
@@ -9,17 +9,17 @@

import pandas as pd
import simplejson
import undetected_chromedriver as uc # type: ignore
from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, computed_field
from selenium.common import MoveTargetOutOfBoundsException, TimeoutException
from selenium.webdriver import ActionChains, Chrome, Keys
from selenium.webdriver.chrome.options import Options
from selenium.webdriver.chrome.webdriver import WebDriver
from selenium.webdriver.remote.webdriver import WebDriver as BaseWebDriver
from selenium.webdriver.remote.webelement import WebElement
from selenium.webdriver.support import expected_conditions
from selenium.webdriver.support.wait import WebDriverWait

from bearish.scrapers.model import HistoricalData
from bearish.scrapers.model import HistoricalData, _clean
from bearish.scrapers.settings import InvestingCountry, TradingCountry
from bearish.scrapers.type import Locator

@@ -85,13 +85,11 @@ def move_by_x_offset_from_left_border(element: BaseElement, x_offset: int) -> bo
return right_border


def init_chrome(load_strategy_none: bool = False, headless: bool = False) -> Chrome:
option = Options()
def init_chrome(headless: bool = True) -> uc.Chrome:
options = {}
if headless:
option.add_argument("--headless")
if load_strategy_none:
option.page_load_strategy = "none"
return Chrome(options=option)
options.update({"headless": True})
return uc.Chrome(use_subprocess=False, version_main=121, **options)


def bearish_path_fun() -> Path:
@@ -105,16 +103,6 @@ class BaseSettings(BaseModel):
...


def clean_dict(data: Dict[str, Any]) -> Dict[str, Any]:
cleaned_data = {}
for name, value in data.items():
if isinstance(value, dict):
cleaned_data[str(name)] = clean_dict(value)
else:
cleaned_data[str(name)] = value
return cleaned_data


def _replace_values(
tables: list[pd.DataFrame], replace_values: Dict[str, str]
) -> list[pd.DataFrame]:
@@ -166,15 +154,6 @@ def _get_country_name_per_enum(
)


def _clean(
data: List[Dict[str, Any]] | Dict[str, Any]
) -> List[Dict[str, Any]] | Dict[str, Any]:
if isinstance(data, list):
return [clean_dict(data_) for data_ in data]
else:
return clean_dict(data)


class CountryNameMixin:
@abc.abstractmethod
def _get_country_name(self) -> str:
@@ -187,6 +166,7 @@ class BasePage(BaseModel):
settings: BaseSettings
browser: WebDriver = Field(default_factory=init_chrome, description="")
bearish_path: Path = Field(default_factory=bearish_path_fun, description="")
first_page_only: Optional[bool] = False
model_config = ConfigDict(arbitrary_types_allowed=True, use_enum_values=True)
_tables = PrivateAttr(default_factory=list)
_skip_existing = PrivateAttr(default=True)
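A minimal usage sketch for the reworked init_chrome helper (illustrative, not part of this commit). It assumes undetected_chromedriver is installed and a local Chrome matching the pinned version_main=121 is available; the target URL and teardown are placeholders.

from bearish.scrapers.base import init_chrome

driver = init_chrome(headless=True)  # now returns an undetected_chromedriver uc.Chrome
try:
    driver.get("https://example.com")  # placeholder URL
finally:
    driver.quit()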
19 changes: 14 additions & 5 deletions bearish/scrapers/investing.py
@@ -1,4 +1,5 @@
import contextlib
import datetime
from functools import partial
from typing import Any, Dict, List, Literal

@@ -24,6 +25,8 @@
from bearish.scrapers.settings import InvestingCountry
from bearish.scrapers.type import Locator

ONE_PAGE = 3

COLUMNS_LENGTH = 2


@@ -78,12 +81,20 @@ def get_statements_urls(self, exchange: str) -> List[str]:
]


class UpdateInvestingSettings(InvestingSettings):
start_date: str = Field(
default_factory=lambda: (
datetime.date.today() - datetime.timedelta(days=1)
).strftime("%d-%m-%Y")
)


class InvestingScreenerScraper(BasePage, CountryNameMixin):
country: int
settings: InvestingSettings = Field(default=InvestingSettings())
source: Literal["trading", "investing", "yahoo"] = "investing"
browser: WebDriver = Field(
default_factory=lambda: init_chrome(load_strategy_none=True, headless=True),
default_factory=lambda: init_chrome(headless=True),
description="",
)

@@ -133,6 +144,8 @@ def read_next_pages(self) -> None:
except (ElementClickInterceptedException, TimeoutException):
break
page_number += 1
if (page_number == ONE_PAGE) and self.first_page_only:
break

def _custom_scrape(self) -> list[dict[str, Any]]:
self.click_one_trust_button()
@@ -145,10 +158,6 @@ class InvestingTickerScraper(BaseTickerPage):
exchange: str
source: Literal["trading", "investing", "yahoo"] = "investing"
settings: InvestingSettings = Field(default=InvestingSettings())
browser: WebDriver = Field(
default_factory=lambda: init_chrome(load_strategy_none=True, headless=False),
description="",
)

@model_validator(mode="before")
@classmethod
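A small sketch of the new UpdateInvestingSettings default (illustrative, not part of this commit): with no arguments, start_date resolves to yesterday's date in day-month-year form, which simply restates the default_factory above.

import datetime

from bearish.scrapers.investing import UpdateInvestingSettings

settings = UpdateInvestingSettings()
yesterday = (datetime.date.today() - datetime.timedelta(days=1)).strftime("%d-%m-%Y")
assert settings.start_date == yesterday  # the new default start date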
57 changes: 48 additions & 9 deletions bearish/scrapers/main.py
@@ -4,8 +4,9 @@
from typing import Any, Dict, Literal, Optional, Type, Union

from pydantic import BaseModel, ConfigDict, Field
from selenium.webdriver.chrome.webdriver import WebDriver

from bearish.scrapers.base import BasePage, bearish_path_fun
from bearish.scrapers.base import BasePage, BaseSettings, bearish_path_fun, init_chrome
from bearish.scrapers.investing import InvestingScreenerScraper, InvestingTickerScraper
from bearish.scrapers.model import Ticker, merge, unflatten_json
from bearish.scrapers.settings import InvestingCountry, TradingCountry
@@ -48,38 +49,59 @@ class Scraper(BaseModel):
)
source: Source
country: Literal["germany", "france", "belgium", "usa"]
settings: Optional[BaseSettings] = None
browser: WebDriver = Field(default_factory=init_chrome, description="")

def _screener_scraper(self) -> BasePage:
def _screener_scraper(self, first_page_only: bool = False) -> BasePage:
return self.source.screener( # type: ignore
country=getattr(self.source.country, self.country),
bearish_path=self.bearish_path,
first_page_only=first_page_only,
settings=self.settings,
browser=self.browser,
)

def scrape(
self, skip_existing: bool = True, symbols: Optional[list[str]] = None
self,
skip_existing: bool = True,
symbols: Optional[list[str]] = None,
first_page_only: bool = False,
) -> None:
screener_scraper = self._screener_scraper()
screener_scraper = self._screener_scraper(first_page_only=first_page_only)
screener_scraper.scrape(skip_existing=skip_existing)
tickers = Ticker.from_json(screener_scraper.get_stored_raw())
tickers = Ticker.from_json(
screener_scraper.get_stored_raw(), source=self.source.screener.source
)
tickers = _filter_by_symbols(tickers=tickers, symbols=symbols)
for ticker in tickers:
scraper = self.source.ticker( # type: ignore
exchange=ticker.reference, bearish_path=self.bearish_path
browser=self.browser,
exchange=ticker.reference,
bearish_path=self.bearish_path,
settings=self.settings,
)
try:
scraper.scrape(skip_existing=skip_existing)
except Exception as e:
logger.error(f"Fail {ticker.reference}. reason: {e}")

def create_db_json(self) -> list[Dict[str, Any]]:
def create_db_json(
self, symbols: Optional[list[str]] = None
) -> list[Dict[str, Any]]:
scraper = self._screener_scraper()
if not scraper.get_stored_raw().exists():
return []
tickers = Ticker.from_json(scraper.get_stored_raw())
tickers = Ticker.from_json(
scraper.get_stored_raw(), source=self.source.ticker.source
)
db_json = []
tickers = _filter_by_symbols(tickers=tickers, symbols=symbols)
for ticker in tickers:
ticker_scraper = self.source.ticker( # type: ignore
browser=None, exchange=ticker.reference, bearish_path=self.bearish_path
browser=self.browser,
exchange=ticker.reference,
bearish_path=self.bearish_path,
settings=self.settings,
)
if not ticker_scraper.get_stored_raw().exists():
continue
@@ -89,3 +111,20 @@ def create_db_json(self) -> list[Dict[str, Any]]:
merge(Ticker, ticker, ticker_)
db_json.append(ticker.model_dump())
return db_json

def update_db_json(self, db_json_path: Path) -> None:
db_json = json.loads(db_json_path.read_text())
tickers = [Ticker(**ticker_json) for ticker_json in db_json]
for ticker in tickers:
ticker_scraper = self.source.ticker( # type: ignore
browser=self.browser,
exchange=ticker.reference,
bearish_path=self.bearish_path,
settings=self.settings,
)
if ticker_scraper.source != ticker.source:
continue
records = ticker_scraper.scrape(skip_existing=False)
if not records:
continue
Ticker.from_record(records, source=ticker.source)
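A sketch of the updated Scraper entry points (illustrative, not part of this commit). Here some_source stands in for a configured Source instance, whose construction is outside this hunk, and the symbol list and db path are placeholders.

from pathlib import Path

from bearish.scrapers.main import Scraper

scraper = Scraper(source=some_source, country="belgium")
scraper.scrape(skip_existing=True, first_page_only=True)  # stop after the first screener page
db_json = scraper.create_db_json(symbols=["ABI"])  # placeholder symbol filter
scraper.update_db_json(Path("db.json"))  # refresh tickers already stored in the db json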
37 changes: 33 additions & 4 deletions bearish/scrapers/model.py
@@ -3,7 +3,7 @@
import json
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Optional, Type, Union
from typing import Any, Dict, List, Literal, Optional, Type, Union

from pydantic import (
AliasChoices,
@@ -358,7 +358,7 @@ class Ticker(BaseTickerModel):
default=None, validation_alias=AliasChoices("Name", "name")
)
symbol: Optional[str] = Field(default=None, validation_alias=AliasChoices("Symbol"))
source: Optional[str] = None
source: Literal["trading", "investing", "yahoo"]
sector: Optional[str] = Field(default=None, validation_alias=AliasChoices("Sector"))
reference: Optional[str] = None
industry: Optional[str] = Field(
@@ -378,9 +378,19 @@ def reference_validator(cls, value: str) -> str:
return value

@classmethod
def from_json(cls, path: Path) -> List["Ticker"]:
def from_json(
cls, path: Path, source: Literal["trading", "investing", "yahoo"]
) -> List["Ticker"]:
records = json.loads(Path(path).read_text())
return [cls(**unflatten_json(cls, record)) for record in records]
return [cls.from_record(record, source) for record in records]

@classmethod
def from_record(
cls,
record: Dict[str, Any] | list[Dict[str, Any]],
source: Literal["trading", "investing", "yahoo"],
) -> "Ticker":
return cls(**(unflatten_json(cls, _clean(record) | {"source": source}))) # type: ignore


def is_nested(schema: Type[BaseModel]) -> bool:
@@ -422,3 +432,22 @@ def unflatten_json(schema: Type[BaseModel], data: Dict[str, Any]) -> Dict[str, A
original_data[name] = unflatten_json(field.annotation, data)
copy_data.update(original_data)
return schema(**copy_data).model_dump()


def _clean(
data: List[Dict[str, Any]] | Dict[str, Any]
) -> List[Dict[str, Any]] | Dict[str, Any]:
if isinstance(data, list):
return [clean_dict(data_) for data_ in data]
else:
return clean_dict(data)


def clean_dict(data: Dict[str, Any]) -> Dict[str, Any]:
cleaned_data = {}
for name, value in data.items():
if isinstance(value, dict):
cleaned_data[str(name)] = clean_dict(value)
else:
cleaned_data[str(name)] = value
return cleaned_data
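A sketch of the relocated helpers and the new Ticker.from_record signature (illustrative, not part of this commit); the record keys below are placeholders.

from bearish.scrapers.model import Ticker, _clean, clean_dict

assert clean_dict({2023: {"Revenue": 10}}) == {"2023": {"Revenue": 10}}  # keys stringified recursively
assert _clean([{1: "a"}, {2: "b"}]) == [{"1": "a"}, {"2": "b"}]  # lists cleaned element-wise

ticker = Ticker.from_record({"Name": "Acme SA", "Symbol": "ACME"}, source="investing")
assert ticker.source == "investing"  # source is now an explicit, required field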
20 changes: 20 additions & 0 deletions bearish/tests/scrapers/conftest.py
@@ -3,6 +3,8 @@

import pytest

from bearish.scrapers.investing import InvestingSettings, UpdateInvestingSettings


@pytest.fixture(scope="session")
def screener_investing() -> Path:
@@ -39,3 +41,21 @@ def investing_record(investing_records: list[dict]) -> dict:
@pytest.fixture(scope="session")
def trading_record(trading_records: list[dict]) -> dict:
return trading_records[0]


@pytest.fixture
def invest_settings() -> InvestingSettings:
return InvestingSettings(
suffixes=[
"-income-statement",
]
)


@pytest.fixture
def update_invest_settings() -> UpdateInvestingSettings:
return UpdateInvestingSettings(
suffixes=[
"-income-statement",
]
)
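A sketch of how a test might consume the new fixtures (illustrative, not part of this commit; the test body is a placeholder).

from bearish.scrapers.investing import UpdateInvestingSettings


def test_update_invest_settings(update_invest_settings: UpdateInvestingSettings) -> None:
    assert update_invest_settings.suffixes == ["-income-statement"]
    assert update_invest_settings.start_date  # defaults to yesterday's date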