diff --git a/.env.example b/.env.example index e1f4ff4..f94328a 100644 --- a/.env.example +++ b/.env.example @@ -9,3 +9,10 @@ X_BOOKING_TOPIC= # User Agent USER_AGENT= + +# Database +POSTGRES_USER=postgres +POSTGRES_PASSWORD=postgres +POSTGRES_HOST=localhost +POSTGRES_PORT=5500 +POSTGRES_DB=postgres \ No newline at end of file diff --git a/README.md b/README.md index 7aed644..4049a1b 100644 --- a/README.md +++ b/README.md @@ -67,6 +67,13 @@ Built on top of [Find the Hotel's Average Room Price in Osaka](#find-the-hotels- - Create a virtual environment and activate it. - Install all dependencies listed in [requirements.txt](requirements.txt) - Rename a `.env.example` to `.env` + +### Setup a Database +- Download [Docker Desktop](https://www.docker.com/products/docker-desktop) +- Ensure that Docker Desktop is running. +- Run: `export POSTGRES_DATA_PATH=''` to set the container volume + to the directory path of your choice. +- Run: `docker compose up -d` ### Find your **User Agent**: - Go to https://www.whatismybrowser.com/detect/what-is-my-user-agent/ @@ -130,41 +137,36 @@ Built on top of [Find the Hotel's Average Room Price in Osaka](#find-the-hotels- ### General Guidelines for Using the Scraper - To scrape only hotel properties, use `--scrape_only_hotel` argument. -- The SQLite database is created automatically if it doesn't exist. - +- Ensure that Docker Desktop and Postgres container are running. + ### To scrape using Whole-Month GraphQL Scraper: - Example usage, with only required arguments for Whole-Month GraphQL Scraper: ```bash - python main.py --whole_mth --year=2024 --month=12 --city=Osaka \ - --sqlite_name=avg_japan_hotel_price_test.db + python main.py --whole_mth --year=2024 --month=12 --city=Osaka ``` - Scrape data start from the given day of the month to the end of the same month. - Default **start day** is 1. - **Start day** can be set with `--start_day` argument. -- Data is saved to **SQLite**. ### To scrape using Basic GraphQL Scraper: - Example usage, with only required arguments for Basic GraphQL Scraper: ```bash - python main.py --city=Osaka --check_in=2024-12-25 --check_out=2024-12-26 --scraper \ - --sqlite_name=avg_japan_hotel_price_test.db + python main.py --city=Osaka --check_in=2024-12-25 --check_out=2024-12-26 --scraper ``` -- Data is saved to **SQLite**. ### To scrape using Japan GraphQL Scraper: - Example usage, with only required arguments for Japan GraphQL Scraper: ```bash - python main.py --japan_hotel --sqlite_name=japan_hotel_data_test.db + python main.py --japan_hotel ``` -- Data is saved to **SQLite**. - Prefecture to scrape can be specified with `--prefecture` argument, for example: - ```bash - python main.py --japan_hotel --prefecture Tokyo --sqlite_name=japan_hotel_data_test.db + python main.py --japan_hotel --prefecture Tokyo ``` - If `--prefecture` argument is not specified, all prefectures will be scraped. - Multiple prefectures can be specified. - ```bash - python main.py --japan_hotel --prefecture Tokyo Osaka --sqlite_name=japan_hotel_data_test.db + python main.py --japan_hotel --prefecture Tokyo Osaka ``` - You can use the prefecture name on Booking.com as a reference. diff --git a/check_missing_dates.py b/check_missing_dates.py index 116948b..0ad493a 100644 --- a/check_missing_dates.py +++ b/check_missing_dates.py @@ -2,11 +2,13 @@ import asyncio import calendar import datetime +import os from calendar import monthrange from dataclasses import dataclass, field from typing import Any -from sqlalchemy import create_engine, func +from dotenv import load_dotenv +from sqlalchemy import create_engine, func, Engine from sqlalchemy.orm import sessionmaker from japan_avg_hotel_price_finder.booking_details import BookingDetails @@ -16,6 +18,14 @@ from japan_avg_hotel_price_finder.sql.db_model import HotelPrice from japan_avg_hotel_price_finder.sql.save_to_db import save_scraped_data +load_dotenv(dotenv_path='.env') + +postgres_host = os.getenv('POSTGRES_HOST') +postgres_port = os.getenv('POSTGRES_PORT') +postgres_user = os.getenv('POSTGRES_USER') +postgres_password = os.getenv('POSTGRES_PASSWORD') +postgres_db = os.getenv('POSTGRES_DB') + def find_missing_dates(dates_in_db: set[str], days_in_month: int, @@ -81,12 +91,14 @@ def filter_past_date(dates_in_db_date_obj: list[datetime.date], today: datetime. async def scrape_missing_dates(missing_dates_list: list[str] = None, booking_details_class: 'BookingDetails' = None, - country: str = 'Japan') -> None: + country: str = 'Japan', + engine: Engine = None) -> None: """ Scrape missing dates with BasicScraper and load them into a database. :param missing_dates_list: Missing dates. :param booking_details_class: Dataclass of booking details as parameters, default is None. :param country: Country where the hotels are located, default is Japan. + :param engine: SQLAlchemy engine. :return: None """ main_logger.info("Scraping missing dates...") @@ -106,15 +118,14 @@ async def scrape_missing_dates(missing_dates_list: list[str] = None, num_rooms = booking_details_class.num_rooms selected_currency = booking_details_class.selected_currency scrape_only_hotel = booking_details_class.scrape_only_hotel - sqlite_name = booking_details_class.sqlite_name scraper = BasicGraphQLScraper(check_in=check_in, check_out=check_out, city=city, group_adults=group_adults, group_children=group_children, num_rooms=num_rooms, selected_currency=selected_currency, - scrape_only_hotel=scrape_only_hotel, sqlite_name=sqlite_name, country=country) + scrape_only_hotel=scrape_only_hotel, country=country) df = await scraper.scrape_graphql() - save_scraped_data(dataframe=df, db=scraper.sqlite_name) + save_scraped_data(dataframe=df, engine=engine) else: main_logger.warning("Missing dates is None. No missing dates to scrape.") @@ -126,18 +137,15 @@ class MissingDateChecker: It only checks the data scraped today, UTC Time. Attributes: - sqlite_name (str): Path to SQLite database. city (str): City where the hotels are located. """ - sqlite_name: str city: str # sqlalchemy - engine: Any = field(init=False) + engine: Any = field(init=True) Session: Any = field(init=False) def __post_init__(self): - self.engine = create_engine(f'sqlite:///{self.sqlite_name}') self.Session = sessionmaker(bind=self.engine) def find_missing_dates_in_db(self, year: int) -> list[str]: @@ -146,7 +154,7 @@ def find_missing_dates_in_db(self, year: int) -> list[str]: :param year: Year of the dates to check whether they are missing. :return: List of missing dates. """ - main_logger.info(f"Checking if all dates were scraped in {self.sqlite_name}...") + main_logger.info(f"Checking if all dates were scraped in a database...") missing_date_list: list[str] = [] session = self.Session() @@ -167,8 +175,7 @@ def find_missing_dates_in_db(self, year: int) -> list[str]: if not count_of_date_by_mth_as_of_today: today = datetime.datetime.now(datetime.timezone.utc).date() - main_logger.warning(f"No scraped data for today, {today}, UTC time for city {self.city} in" - f" {self.sqlite_name}.") + main_logger.warning(f"No scraped data for today, {today}, UTC time for city {self.city} in a database") return missing_date_list today = datetime.datetime.today() @@ -255,8 +262,6 @@ def parse_arguments() -> argparse.Namespace: :return: argparse.Namespace """ parser = argparse.ArgumentParser(description='Parser which controls Missing Date Checker.') - parser.add_argument('--sqlite_name', type=str, default='avg_japan_hotel_price_test.db', - help='SQLite database path, default is "avg_japan_hotel_price_test.db"') parser.add_argument('--city', type=str, help='City where the hotels are located', required=True) parser.add_argument('--group_adults', type=int, default=1, help='Number of Adults, default is 1') parser.add_argument('--num_rooms', type=int, default=1, help='Number of Rooms, default is 1') @@ -274,11 +279,12 @@ def parse_arguments() -> argparse.Namespace: args = parse_arguments() booking_details = BookingDetails(city=args.city, group_adults=args.group_adults, - num_rooms=args.num_rooms, group_children=args.group_children, - selected_currency=args.selected_currency, - scrape_only_hotel=args.scrape_only_hotel, sqlite_name=args.sqlite_name) + num_rooms=args.num_rooms, group_children=args.group_children, + selected_currency=args.selected_currency, + scrape_only_hotel=args.scrape_only_hotel) - db_path: str = args.sqlite_name - missing_date_checker = MissingDateChecker(sqlite_name=db_path, city=args.city) + postgres_url = f"postgresql://{postgres_user}:{postgres_password}@{postgres_host}:{postgres_port}/{postgres_db}" + engine = create_engine(postgres_url) + missing_date_checker = MissingDateChecker(engine=engine, city=args.city) missing_dates: list[str] = missing_date_checker.find_missing_dates_in_db(year=args.year) - asyncio.run(scrape_missing_dates(missing_dates, booking_details_class=booking_details)) + asyncio.run(scrape_missing_dates(missing_dates, booking_details_class=booking_details, engine=engine)) diff --git a/docker-compose.yml b/docker-compose.yml index 1cbfc5b..fbf90b3 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -6,11 +6,13 @@ services: container_name: jp_scraper hostname: jp_scraper environment: - POSTGRES_USER: postgres - POSTGRES_PASSWORD: postgres - POSTGRES_DB: postgres + - POSTGRES_USER=postgres + - POSTGRES_PASSWORD=postgres + - POSTGRES_HOST=localhost + - POSTGRES_PORT=5500 + - POSTGRES_DB=postgres volumes: - - ./postgres_data:/var/lib/postgresql/data + - ${POSTGRES_DATA_PATH}:/var/lib/postgresql/data ports: - "5500:5432" networks: diff --git a/docs/SCRAPER_ARGS.md b/docs/SCRAPER_ARGS.md index ff1ec3b..47a2588 100644 --- a/docs/SCRAPER_ARGS.md +++ b/docs/SCRAPER_ARGS.md @@ -53,10 +53,6 @@ - **Type**: `bool` - **Description**: If set to `True`, the scraper will only target hotel properties. -### `--sqlite_name` -- **Type**: `str` -- **Description**: The name of the SQLite database file to use. Only used for Basic and Whole-Month Scraper. - ### `--year` - **Type**: `int` - **Description**: Specifies the year to scrape. This argument is required for Whole-Month Scraper. diff --git a/japan_avg_hotel_price_finder/booking_details.py b/japan_avg_hotel_price_finder/booking_details.py index 35edf69..5e54aab 100644 --- a/japan_avg_hotel_price_finder/booking_details.py +++ b/japan_avg_hotel_price_finder/booking_details.py @@ -15,7 +15,6 @@ class BookingDetails(BaseModel): - group_children (int): Number of children. - selected_currency (str): Room price currency. - scrape_only_hotel (bool): Whether to scrape only hotel. - - sqlite_name (str): Path to SQLite database. """ city: str = '' country: str = '' @@ -25,5 +24,4 @@ class BookingDetails(BaseModel): num_rooms: int = Field(1, gt=0) group_children: int = Field(0, ge=0) selected_currency: str = '' - scrape_only_hotel: bool = True - sqlite_name: str = '' \ No newline at end of file + scrape_only_hotel: bool = True \ No newline at end of file diff --git a/japan_avg_hotel_price_finder/graphql_scraper.py b/japan_avg_hotel_price_finder/graphql_scraper.py index 16552a3..f38e7f6 100644 --- a/japan_avg_hotel_price_finder/graphql_scraper.py +++ b/japan_avg_hotel_price_finder/graphql_scraper.py @@ -38,11 +38,7 @@ class BasicGraphQLScraper(BaseModel): group_children (str): Number of children, default is 0. selected_currency (str): Currency of the room price, default is USD. scrape_only_hotel (bool): Whether to scrape only the hotel property data, default is True - sqlite_name (str): Name of SQLite database to store the scraped data. """ - # Set SQLite database name - sqlite_name: str - # Set booking details. city: str country: str diff --git a/japan_avg_hotel_price_finder/japan_hotel_scraper.py b/japan_avg_hotel_price_finder/japan_hotel_scraper.py index cc9c1e6..5f2f37f 100644 --- a/japan_avg_hotel_price_finder/japan_hotel_scraper.py +++ b/japan_avg_hotel_price_finder/japan_hotel_scraper.py @@ -2,8 +2,8 @@ from typing import Any import pandas as pd -from pydantic import Field -from sqlalchemy import create_engine +from pydantic import Field, ConfigDict +from sqlalchemy import Engine from sqlalchemy.orm import sessionmaker from japan_avg_hotel_price_finder.configure_logging import main_logger @@ -32,8 +32,11 @@ class JapanScraper(WholeMonthGraphQLScraper): region (str): The current region being scraped. start_month (int): Month to start scraping (1-12). end_month (int): Last month to scrape (1-12). - sqlite_name (str): Path and name of SQLite database to store the scraped data. + engine (Engine): SQLAlchemy engine. """ + engine: Engine + + model_config = ConfigDict(arbitrary_types_allowed=True) japan_regions: dict[str, list[str]] = { "Hokkaido": ["Hokkaido"], @@ -114,17 +117,17 @@ async def _scrape_whole_year(self) -> None: df = await self.scrape_whole_month() if not df.empty: df['Region'] = self.region - self._load_to_sqlite(df) + self._load_to_database(df) else: main_logger.warning(f"No data found for {self.city} for {calendar.month_name[self.month]} {self.year}") - def _load_to_sqlite(self, prefecture_hotel_data: pd.DataFrame) -> None: + def _load_to_database(self, prefecture_hotel_data: pd.DataFrame) -> None: """ - Load hotel data of all Japan Prefectures to SQLite using SQLAlchemy ORM + Load hotel data of all Japan Prefectures to a database using SQLAlchemy ORM :param prefecture_hotel_data: DataFrame with the whole-year hotel data of the given prefecture. :return: None """ - main_logger.info(f"Loading hotel data to SQLite {self.sqlite_name}...") + main_logger.info(f"Loading hotel data to database...") # Rename 'City' column to 'Prefecture' prefecture_hotel_data = prefecture_hotel_data.rename(columns={'City': 'Prefecture'}) @@ -132,9 +135,10 @@ def _load_to_sqlite(self, prefecture_hotel_data: pd.DataFrame) -> None: # Rename Price/Review column prefecture_hotel_data.rename(columns={'Price/Review': 'PriceReview'}, inplace=True) - engine = create_engine(f'sqlite:///{self.sqlite_name}') - Base.metadata.tables['JapanHotels'].create(engine, checkfirst=True) - Session = sessionmaker(bind=engine) + # Create all tables + Base.metadata.create_all(self.engine) + + Session = sessionmaker(bind=self.engine) session = Session() try: @@ -148,7 +152,7 @@ def _load_to_sqlite(self, prefecture_hotel_data: pd.DataFrame) -> None: session.bulk_save_objects(hotel_prices) session.commit() - main_logger.info(f"Hotel data for {self.city} loaded to SQLite successfully.") + main_logger.info(f"Hotel data for {self.city} loaded to a database successfully.") except Exception as e: session.rollback() main_logger.error(f"An error occurred while saving data: {str(e)}") diff --git a/japan_avg_hotel_price_finder/main_argparse.py b/japan_avg_hotel_price_finder/main_argparse.py index 2600637..75a9309 100644 --- a/japan_avg_hotel_price_finder/main_argparse.py +++ b/japan_avg_hotel_price_finder/main_argparse.py @@ -32,16 +32,6 @@ def add_booking_details_arguments(parser: argparse.ArgumentParser) -> None: parser.add_argument('--scrape_only_hotel', action='store_true', help='Whether to scrape only hotel properties') -def add_database_arguments(parser: argparse.ArgumentParser) -> None: - """ - Add database-related arguments to the parser. - :param parser: argparse.ArgumentParser - :return: None - """ - db_group = parser.add_mutually_exclusive_group(required=True) - db_group.add_argument('--sqlite_name', type=str, help='SQLite database path') - - def add_date_arguments(parser: argparse.ArgumentParser) -> None: """ Add date and length of stay arguments to the parser. @@ -101,7 +91,6 @@ def parse_arguments() -> argparse.Namespace: parser = argparse.ArgumentParser(description='Parser that controls which kind of scraper to use.') add_scraper_arguments(parser) add_booking_details_arguments(parser) - add_database_arguments(parser) add_date_arguments(parser) add_japan_arguments(parser) args = parser.parse_args() diff --git a/japan_avg_hotel_price_finder/sql/migrate_to_sqlite.py b/japan_avg_hotel_price_finder/sql/migrate_to_sqlite.py deleted file mode 100644 index 1ba0631..0000000 --- a/japan_avg_hotel_price_finder/sql/migrate_to_sqlite.py +++ /dev/null @@ -1,245 +0,0 @@ -import pandas as pd -from sqlalchemy import create_engine, func, case, MetaData -from sqlalchemy.orm import sessionmaker, Session - -from japan_avg_hotel_price_finder.configure_logging import main_logger -from japan_avg_hotel_price_finder.sql.db_model import Base, HotelPrice, AverageRoomPriceByDate, \ - AverageHotelRoomPriceByReview, AverageHotelRoomPriceByDayOfWeek, AverageHotelRoomPriceByMonth, \ - AverageHotelRoomPriceByLocation - - -def migrate_data_to_sqlite(df_filtered: pd.DataFrame, db: str) -> None: - """ - Migrate hotel data to sqlite database using SQLAlchemy ORM. - :param df_filtered: pandas dataframe. - :param db: SQLite database path. - :return: None - """ - main_logger.info('Connecting to SQLite database (or create it if it doesn\'t exist)...') - - engine = create_engine(f'sqlite:///{db}') - - # Create a new MetaData instance - metadata = MetaData() - - # Copy all tables from Base.metadata except JapanHotels - for table_name, table in Base.metadata.tables.items(): - if table_name != 'JapanHotels': - table.tometadata(metadata) - - # Create all tables in the new metadata - metadata.create_all(engine) - - Session = sessionmaker(bind=engine) - session = Session() - - try: - # Rename Price/Review column - df_filtered.rename(columns={'Price/Review': 'PriceReview'}, inplace=True) - - # Convert DataFrame to list of dictionaries - records = df_filtered.to_dict('records') - - # Bulk insert records - session.bulk_insert_mappings(HotelPrice, records) - - create_avg_hotel_room_price_by_date_table(session) - create_avg_room_price_by_review_table(session) - create_avg_hotel_price_by_dow_table(session) - create_avg_hotel_price_by_month_table(session) - create_avg_room_price_by_location(session) - - session.commit() - main_logger.info(f'Data has been saved to {db}') - except Exception as e: - session.rollback() - main_logger.error(f"An unexpected error occurred: {str(e)}") - main_logger.error("Database changes have been rolled back.") - raise - finally: - session.close() - - -def create_avg_hotel_room_price_by_date_table(session: Session) -> None: - """ - Create AverageHotelRoomPriceByDate table using SQLAlchemy ORM - :param session: SQLAlchemy session - :return: None - """ - main_logger.info('Create AverageRoomPriceByDate table...') - - # Clear existing data - session.query(AverageRoomPriceByDate).delete() - - # Insert new data - avg_prices = session.query( - HotelPrice.Date, - func.avg(HotelPrice.Price).label('AveragePrice'), - HotelPrice.City - ).group_by(HotelPrice.Date, HotelPrice.City).all() - - new_records = [ - AverageRoomPriceByDate(Date=date, AveragePrice=avg_price, City=city) - for date, avg_price, city in avg_prices - ] - - session.bulk_save_objects(new_records) - session.commit() - - -def create_avg_room_price_by_review_table(session: Session) -> None: - """ - Create AverageHotelRoomPriceByReview table using SQLAlchemy ORM. - :param session: SQLAlchemy session - :return: None - """ - main_logger.info("Create AverageHotelRoomPriceByReview table...") - - # Clear existing data - session.query(AverageHotelRoomPriceByReview).delete() - - # Calculate average prices by review - avg_prices = session.query( - HotelPrice.Review, - func.avg(HotelPrice.Price).label('AveragePrice') - ).group_by(HotelPrice.Review).all() - - # Create new records - new_records = [ - AverageHotelRoomPriceByReview(Review=review, AveragePrice=avg_price) - for review, avg_price in avg_prices - ] - - # Bulk insert new records - session.bulk_save_objects(new_records) - session.commit() - - -def create_avg_hotel_price_by_dow_table(session: Session) -> None: - """ - Create AverageHotelRoomPriceByDayOfWeek table using SQLAlchemy ORM. - :param session: SQLAlchemy session - :return: None - """ - main_logger.info("Create AverageHotelRoomPriceByDayOfWeek table...") - - # Clear existing data - session.query(AverageHotelRoomPriceByDayOfWeek).delete() - - # Calculate average prices by day of week - day_of_week_case = case( - (func.strftime('%w', HotelPrice.Date) == '0', 'Sunday'), - (func.strftime('%w', HotelPrice.Date) == '1', 'Monday'), - (func.strftime('%w', HotelPrice.Date) == '2', 'Tuesday'), - (func.strftime('%w', HotelPrice.Date) == '3', 'Wednesday'), - (func.strftime('%w', HotelPrice.Date) == '4', 'Thursday'), - (func.strftime('%w', HotelPrice.Date) == '5', 'Friday'), - (func.strftime('%w', HotelPrice.Date) == '6', 'Saturday'), - ).label('day_of_week') - - avg_prices = session.query( - day_of_week_case, - func.avg(HotelPrice.Price).label('avg_price') - ).group_by(day_of_week_case).all() - - # Create new records - new_records = [ - AverageHotelRoomPriceByDayOfWeek(DayOfWeek=day_of_week, AveragePrice=avg_price) - for day_of_week, avg_price in avg_prices - ] - - # Bulk insert new records - session.bulk_save_objects(new_records) - session.commit() - - -def create_avg_hotel_price_by_month_table(session: Session) -> None: - """ - Create AverageHotelRoomPriceByMonth table using SQLAlchemy ORM. - :param session: SQLAlchemy session - :return: None - """ - main_logger.info("Create AverageHotelRoomPriceByMonth table...") - - # Clear existing data - session.query(AverageHotelRoomPriceByMonth).delete() - - # Define the month case - month_case = case( - (func.strftime('%m', HotelPrice.Date) == '01', 'January'), - (func.strftime('%m', HotelPrice.Date) == '02', 'February'), - (func.strftime('%m', HotelPrice.Date) == '03', 'March'), - (func.strftime('%m', HotelPrice.Date) == '04', 'April'), - (func.strftime('%m', HotelPrice.Date) == '05', 'May'), - (func.strftime('%m', HotelPrice.Date) == '06', 'June'), - (func.strftime('%m', HotelPrice.Date) == '07', 'July'), - (func.strftime('%m', HotelPrice.Date) == '08', 'August'), - (func.strftime('%m', HotelPrice.Date) == '09', 'September'), - (func.strftime('%m', HotelPrice.Date) == '10', 'October'), - (func.strftime('%m', HotelPrice.Date) == '11', 'November'), - (func.strftime('%m', HotelPrice.Date) == '12', 'December'), - ).label('month') - - # Define the quarter case - quarter_case = case( - (func.strftime('%m', HotelPrice.Date).in_(['01', '02', '03']), 'Quarter1'), - (func.strftime('%m', HotelPrice.Date).in_(['04', '05', '06']), 'Quarter2'), - (func.strftime('%m', HotelPrice.Date).in_(['07', '08', '09']), 'Quarter3'), - (func.strftime('%m', HotelPrice.Date).in_(['10', '11', '12']), 'Quarter4'), - ).label('quarter') - - # Calculate average prices by month - avg_prices = session.query( - month_case, - func.avg(HotelPrice.Price).label('avg_price'), - quarter_case - ).group_by(month_case).all() - - # Create new records - new_records = [ - AverageHotelRoomPriceByMonth(Month=month, AveragePrice=avg_price, Quarter=quarter) - for month, avg_price, quarter in avg_prices - ] - - # Bulk insert new records - session.bulk_save_objects(new_records) - session.commit() - - -def create_avg_room_price_by_location(session: Session) -> None: - """ - Create AverageHotelRoomPriceByLocation table using SQLAlchemy ORM. - :param session: SQLAlchemy session - :return: None - """ - main_logger.info("Create AverageHotelRoomPriceByLocation table...") - - # Clear existing data - session.query(AverageHotelRoomPriceByLocation).delete() - - # Calculate average prices, ratings, and price per review by location - avg_data = session.query( - HotelPrice.Location, - func.avg(HotelPrice.Price).label('AveragePrice'), - func.avg(HotelPrice.Review).label('AverageRating'), - func.avg(HotelPrice.PriceReview).label('AveragePricePerReview') - ).group_by(HotelPrice.Location).all() - - # Create new records - new_records = [ - AverageHotelRoomPriceByLocation( - Location=location, - AveragePrice=avg_price, - AverageRating=avg_rating, - AveragePricePerReview=avg_price_per_review - ) - for location, avg_price, avg_rating, avg_price_per_review in avg_data - ] - - # Bulk insert new records - session.bulk_save_objects(new_records) - session.commit() - - -if __name__ == '__main__': - pass diff --git a/japan_avg_hotel_price_finder/sql/save_to_db.py b/japan_avg_hotel_price_finder/sql/save_to_db.py index 3a1ccb3..dcb6cd1 100644 --- a/japan_avg_hotel_price_finder/sql/save_to_db.py +++ b/japan_avg_hotel_price_finder/sql/save_to_db.py @@ -1,19 +1,271 @@ import pandas as pd +from sqlalchemy import create_engine, MetaData, func, case, Engine, Float, extract, Integer, cast +from sqlalchemy.dialects import sqlite, postgresql +from sqlalchemy.orm import sessionmaker, Session from japan_avg_hotel_price_finder.configure_logging import main_logger -from japan_avg_hotel_price_finder.sql.migrate_to_sqlite import migrate_data_to_sqlite +from japan_avg_hotel_price_finder.sql.db_model import Base, HotelPrice, AverageRoomPriceByDate, \ + AverageHotelRoomPriceByReview, AverageHotelRoomPriceByDayOfWeek, AverageHotelRoomPriceByMonth, \ + AverageHotelRoomPriceByLocation -def save_scraped_data(dataframe: pd.DataFrame, db: str) -> None: +def save_scraped_data(dataframe: pd.DataFrame, engine: Engine) -> None: """ - Save scraped data to SQLite database. + Save scraped data to a database. :param dataframe: Pandas DataFrame. - :param db: SQLite database path. + :param engine: SQLAlchemy engine. :return: None """ main_logger.info("Saving scraped data...") if not dataframe.empty: - main_logger.info(f'Save data to SQLite database: {db}') - migrate_data_to_sqlite(dataframe, db) + main_logger.info(f'Save data to a database') + migrate_data_to_database(dataframe, engine) else: - main_logger.warning('The dataframe is empty. No data to save') \ No newline at end of file + main_logger.warning('The dataframe is empty. No data to save') + + +def migrate_data_to_database(df_filtered: pd.DataFrame, engine: Engine) -> None: + """ + Migrate hotel data to a database using SQLAlchemy ORM. + :param df_filtered: pandas dataframe. + :param engine: SQLAlchemy engine. + :return: None + """ + main_logger.info('Connecting to a database (or create it if it doesn\'t exist)...') + + # Create all tables + Base.metadata.create_all(engine) + + Session = sessionmaker(bind=engine) + session = Session() + + try: + # Rename Price/Review column + df_filtered.rename(columns={'Price/Review': 'PriceReview'}, inplace=True) + + # Convert DataFrame to list of dictionaries + records = df_filtered.to_dict('records') + + # Bulk insert records + session.bulk_insert_mappings(HotelPrice, records) + + create_avg_hotel_room_price_by_date_table(session) + create_avg_room_price_by_review_table(session) + create_avg_hotel_price_by_dow_table(session) + create_avg_hotel_price_by_month_table(session) + create_avg_room_price_by_location(session) + + session.commit() + main_logger.info(f'Data has been saved to a database successfully.') + except Exception as e: + session.rollback() + main_logger.error(f"An unexpected error occurred: {str(e)}") + main_logger.error("Database changes have been rolled back.") + raise + finally: + session.close() + + +def create_avg_hotel_room_price_by_date_table(session: Session) -> None: + """ + Create AverageHotelRoomPriceByDate table using SQLAlchemy ORM + :param session: SQLAlchemy session + :return: None + """ + main_logger.info('Create AverageRoomPriceByDate table...') + + # Clear existing data + session.query(AverageRoomPriceByDate).delete() + + # Insert new data + avg_prices = session.query( + HotelPrice.Date, + func.avg(HotelPrice.Price).label('AveragePrice'), + HotelPrice.City + ).group_by(HotelPrice.Date, HotelPrice.City).all() + + new_records = [ + AverageRoomPriceByDate(Date=date, AveragePrice=avg_price, City=city) + for date, avg_price, city in avg_prices + ] + + session.bulk_save_objects(new_records) + session.commit() + + +def create_avg_room_price_by_review_table(session: Session) -> None: + """ + Create AverageHotelRoomPriceByReview table using SQLAlchemy ORM. + :param session: SQLAlchemy session + :return: None + """ + main_logger.info("Create AverageHotelRoomPriceByReview table...") + + # Clear existing data + session.query(AverageHotelRoomPriceByReview).delete() + + # Calculate average prices by review, rounding review to nearest integer + avg_prices = session.query( + func.round(HotelPrice.Review), + func.avg(HotelPrice.Price).label('AveragePrice') + ).group_by(func.round(HotelPrice.Review)).all() + + # Create new records + new_records = [ + AverageHotelRoomPriceByReview(Review=review, AveragePrice=avg_price) + for review, avg_price in avg_prices + ] + + # Bulk insert new records + session.bulk_save_objects(new_records) + session.commit() + + +def create_avg_hotel_price_by_dow_table(session: Session) -> None: + """ + Create AverageHotelRoomPriceByDayOfWeek table using SQLAlchemy ORM. + :param session: SQLAlchemy session + :return: None + """ + main_logger.info("Create AverageHotelRoomPriceByDayOfWeek table...") + + # Clear existing data + session.query(AverageHotelRoomPriceByDayOfWeek).delete() + + # Detect database dialect + dialect = session.bind.dialect + + if isinstance(dialect, postgresql.dialect): + # PostgreSQL specific date extraction + dow_func = extract('dow', func.to_date(HotelPrice.Date, 'YYYY-MM-DD')) + elif isinstance(dialect, sqlite.dialect): + # SQLite specific date extraction + dow_func = func.cast(func.strftime('%w', func.date(HotelPrice.Date)), Integer) + else: + raise NotImplementedError(f"Unsupported dialect: {dialect}") + + # Calculate average prices by day of week + day_of_week_case = case( + (dow_func == 0, 'Sunday'), + (dow_func == 1, 'Monday'), + (dow_func == 2, 'Tuesday'), + (dow_func == 3, 'Wednesday'), + (dow_func == 4, 'Thursday'), + (dow_func == 5, 'Friday'), + (dow_func == 6, 'Saturday'), + ).label('day_of_week') + + avg_prices = session.query( + day_of_week_case, + func.avg(HotelPrice.Price).label('avg_price') + ).group_by(day_of_week_case).all() + + # Create new records + new_records = [ + AverageHotelRoomPriceByDayOfWeek(DayOfWeek=day_of_week, AveragePrice=avg_price) + for day_of_week, avg_price in avg_prices + ] + + # Bulk insert new records + session.bulk_save_objects(new_records) + session.commit() + + +def create_avg_hotel_price_by_month_table(session: Session) -> None: + """ + Create AverageHotelRoomPriceByMonth table using SQLAlchemy ORM. + :param session: SQLAlchemy session + :return: None + """ + main_logger.info("Create AverageHotelRoomPriceByMonth table...") + + # Clear existing data + session.query(AverageHotelRoomPriceByMonth).delete() + + # Detect database dialect + dialect = session.bind.dialect + + if isinstance(dialect, postgresql.dialect): + # PostgreSQL specific date extraction + month_func = extract('month', func.to_date(HotelPrice.Date, 'YYYY-MM-DD')) + elif isinstance(dialect, sqlite.dialect): + # SQLite specific date extraction + month_func = cast(func.strftime('%m', HotelPrice.Date), Integer) + else: + raise NotImplementedError(f"Unsupported dialect: {dialect}") + + # Define the month case + month_case = case( + (month_func == 1, 'January'), + (month_func == 2, 'February'), + (month_func == 3, 'March'), + (month_func == 4, 'April'), + (month_func == 5, 'May'), + (month_func == 6, 'June'), + (month_func == 7, 'July'), + (month_func == 8, 'August'), + (month_func == 9, 'September'), + (month_func == 10, 'October'), + (month_func == 11, 'November'), + (month_func == 12, 'December'), + ).label('month') + + # Define the quarter case + quarter_case = case( + (month_func.in_([1, 2, 3]), 'Quarter1'), + (month_func.in_([4, 5, 6]), 'Quarter2'), + (month_func.in_([7, 8, 9]), 'Quarter3'), + (month_func.in_([10, 11, 12]), 'Quarter4'), + ).label('quarter') + + # Calculate average prices by month + avg_prices = session.query( + month_case, + func.avg(HotelPrice.Price).label('avg_price'), + quarter_case + ).group_by(month_case, quarter_case).all() + + # Create new records + new_records = [ + AverageHotelRoomPriceByMonth(Month=month, AveragePrice=avg_price, Quarter=quarter) + for month, avg_price, quarter in avg_prices + ] + + # Bulk insert new records + session.bulk_save_objects(new_records) + session.commit() + + +def create_avg_room_price_by_location(session: Session) -> None: + """ + Create AverageHotelRoomPriceByLocation table using SQLAlchemy ORM. + :param session: SQLAlchemy session + :return: None + """ + main_logger.info("Create AverageHotelRoomPriceByLocation table...") + + # Clear existing data + session.query(AverageHotelRoomPriceByLocation).delete() + + # Calculate average prices, ratings, and price per review by location + avg_data = session.query( + HotelPrice.Location, + func.avg(HotelPrice.Price).label('AveragePrice'), + func.avg(HotelPrice.Review).label('AverageRating'), + func.avg(HotelPrice.PriceReview).label('AveragePricePerReview') + ).group_by(HotelPrice.Location).all() + + # Create new records + new_records = [ + AverageHotelRoomPriceByLocation( + Location=location, + AveragePrice=avg_price, + AverageRating=avg_rating, + AveragePricePerReview=avg_price_per_review + ) + for location, avg_price, avg_rating, avg_price_per_review in avg_data + ] + + # Bulk insert new records + session.bulk_save_objects(new_records) + session.commit() \ No newline at end of file diff --git a/japan_avg_hotel_price_finder/whole_mth_graphql_scraper.py b/japan_avg_hotel_price_finder/whole_mth_graphql_scraper.py index 3ce65a7..71f17b7 100644 --- a/japan_avg_hotel_price_finder/whole_mth_graphql_scraper.py +++ b/japan_avg_hotel_price_finder/whole_mth_graphql_scraper.py @@ -29,7 +29,6 @@ class WholeMonthGraphQLScraper(BasicGraphQLScraper): nights (int): Number of nights (Length of stay) which defines the room price. For example, nights = 1 means scraping the hotel with room price for 1 night. Default is 1. - sqlite_name (str): Path and name of SQLite database to store the scraped data. """ # Set the start day, month, year, and length of stay year: int = Field(datetime.datetime.now().year, gt=0) diff --git a/main.py b/main.py index d3b38cc..6b46d1d 100644 --- a/main.py +++ b/main.py @@ -1,5 +1,9 @@ import argparse import asyncio +import os + +from dotenv import load_dotenv +from sqlalchemy import Engine, create_engine from japan_avg_hotel_price_finder.configure_logging import main_logger from japan_avg_hotel_price_finder.graphql_scraper import BasicGraphQLScraper @@ -8,6 +12,8 @@ from japan_avg_hotel_price_finder.sql.save_to_db import save_scraped_data from japan_avg_hotel_price_finder.whole_mth_graphql_scraper import WholeMonthGraphQLScraper +load_dotenv() + def validate_required_args(arguments: argparse.Namespace, required_args: list[str]) -> bool: """ @@ -23,29 +29,31 @@ def validate_required_args(arguments: argparse.Namespace, required_args: list[st return True -def run_whole_month_scraper(arguments: argparse.Namespace) -> None: +def run_whole_month_scraper(arguments: argparse.Namespace, engine: Engine) -> None: """ Run the Whole-Month GraphQL scraper :param arguments: Arguments to pass to the scraper + :param engine: SQLAlchemy engine :return: None """ required_args = ['year', 'month', 'city', 'country'] if validate_required_args(arguments, required_args): scraper = WholeMonthGraphQLScraper( city=arguments.city, year=arguments.year, month=arguments.month, start_day=arguments.start_day, - nights=arguments.nights, scrape_only_hotel=arguments.scrape_only_hotel, sqlite_name=arguments.sqlite_name, + nights=arguments.nights, scrape_only_hotel=arguments.scrape_only_hotel, selected_currency=arguments.selected_currency, group_adults=arguments.group_adults, num_rooms=arguments.num_rooms, group_children=arguments.group_children, check_in='', check_out='', country=arguments.country ) df = asyncio.run(scraper.scrape_whole_month()) - save_scraped_data(dataframe=df, db=scraper.sqlite_name) + save_scraped_data(dataframe=df, engine=engine) -def run_japan_hotel_scraper(arguments: argparse.Namespace) -> None: +def run_japan_hotel_scraper(arguments: argparse.Namespace, engine: Engine) -> None: """ Run the Japan hotel scraper :param arguments: Arguments to pass to the scraper + :param engine: SQLAlchemy engine :return: None """ if arguments.prefecture: @@ -57,30 +65,30 @@ def run_japan_hotel_scraper(arguments: argparse.Namespace) -> None: month: int = 1 scraper = JapanScraper( city=city, year=year, month=month, start_day=arguments.start_day, nights=arguments.nights, - scrape_only_hotel=arguments.scrape_only_hotel, sqlite_name=arguments.sqlite_name, - selected_currency=arguments.selected_currency, group_adults=arguments.group_adults, - num_rooms=arguments.num_rooms, group_children=arguments.group_children, check_in='', check_out='', - country=arguments.country + scrape_only_hotel=arguments.scrape_only_hotel, selected_currency=arguments.selected_currency, + group_adults=arguments.group_adults, num_rooms=arguments.num_rooms, group_children=arguments.group_children, + check_in='', check_out='', country=arguments.country, engine=engine ) asyncio.run(scraper.scrape_japan_hotels()) -def run_basic_scraper(arguments: argparse.Namespace) -> None: +def run_basic_scraper(arguments: argparse.Namespace, engine: Engine) -> None: """ Run the Basic GraphQL scraper :param arguments: Arguments to pass to the scraper + :param engine: SQLAlchemy engine :return: None """ required_args = ['check_in', 'check_out', 'city', 'country'] if validate_required_args(arguments, required_args): scraper = BasicGraphQLScraper( - city=arguments.city, scrape_only_hotel=arguments.scrape_only_hotel, sqlite_name=arguments.sqlite_name, + city=arguments.city, scrape_only_hotel=arguments.scrape_only_hotel, selected_currency=arguments.selected_currency, group_adults=arguments.group_adults, num_rooms=arguments.num_rooms, group_children=arguments.group_children, check_in=arguments.check_in, check_out=arguments.check_out, country=arguments.country ) df = asyncio.run(scraper.scrape_graphql()) - save_scraped_data(dataframe=df, db=scraper.sqlite_name) + save_scraped_data(dataframe=df, engine=engine) def main() -> None: @@ -89,12 +97,16 @@ def main() -> None: :return: None """ arguments = parse_arguments() + postgres_url = (f"postgresql://{os.getenv('POSTGRES_USER')}:{os.getenv('POSTGRES_PASSWORD')}" + f"@{os.getenv('POSTGRES_HOST')}:{os.getenv('POSTGRES_PORT')}/{os.getenv('POSTGRES_DB')}") + engine = create_engine(postgres_url) + if arguments.whole_mth: - run_whole_month_scraper(arguments) + run_whole_month_scraper(arguments, engine) elif arguments.japan_hotel: - run_japan_hotel_scraper(arguments) + run_japan_hotel_scraper(arguments, engine) else: - run_basic_scraper(arguments) + run_basic_scraper(arguments, engine) if __name__ == '__main__': diff --git a/requirements.txt b/requirements.txt index 80f4704..e028111 100644 --- a/requirements.txt +++ b/requirements.txt @@ -10,4 +10,5 @@ python-dotenv~=1.0.1 pytest-mock~=3.14.0 pydantic~=2.9.2 ruff~=0.7.1 -SQLAlchemy~=2.0.36 \ No newline at end of file +SQLAlchemy~=2.0.36 +psycopg2~=2.9.10 \ No newline at end of file diff --git a/tests/test_basic_graphql_scraper/test_graphql_scraper.py b/tests/test_basic_graphql_scraper/test_graphql_scraper.py index 0aafc58..818cd85 100644 --- a/tests/test_basic_graphql_scraper/test_graphql_scraper.py +++ b/tests/test_basic_graphql_scraper/test_graphql_scraper.py @@ -16,7 +16,7 @@ async def test_graphql_scraper(): check_in = today.strftime('%Y-%m-%d') tomorrow = today + datetime.timedelta(days=1) check_out = tomorrow.strftime('%Y-%m-%d') - scraper = BasicGraphQLScraper(city='Osaka', num_rooms=1, group_adults=1, group_children=0, sqlite_name='', check_out=check_out, + scraper = BasicGraphQLScraper(city='Osaka', num_rooms=1, group_adults=1, group_children=0, check_out=check_out, check_in=check_in, selected_currency='USD', scrape_only_hotel=True, country=country) df = await scraper.scrape_graphql() @@ -33,7 +33,7 @@ async def test_graphql_scraper_incorrect_date(): check_in = today.strftime('%Y-%m-%d') yesterday = today - datetime.timedelta(days=1) check_out = yesterday.strftime('%Y-%m-%d') - scraper = BasicGraphQLScraper(city='Osaka', num_rooms=1, group_adults=1, group_children=0, sqlite_name='', check_out=check_out, + scraper = BasicGraphQLScraper(city='Osaka', num_rooms=1, group_adults=1, group_children=0, check_out=check_out, check_in=check_in, selected_currency='USD', scrape_only_hotel=True, country=country) with pytest.raises(pydantic.ValidationError): diff --git a/tests/test_japan_scraper/test_japan_scraper.py b/tests/test_japan_scraper/test_japan_scraper.py index c657ce5..3f904a8 100644 --- a/tests/test_japan_scraper/test_japan_scraper.py +++ b/tests/test_japan_scraper/test_japan_scraper.py @@ -1,20 +1,118 @@ -import asyncio import datetime -import sqlite3 +from unittest.mock import patch, AsyncMock + +import pytest +from sqlalchemy import create_engine, text, Table, Column, Integer, String, Date, Boolean, MetaData +import pandas as pd from japan_avg_hotel_price_finder.japan_hotel_scraper import JapanScraper -def test_japan_scraper(tmp_path): +@pytest.mark.asyncio +async def test_japan_scraper(tmp_path): db = str(tmp_path / 'test_japan_scraper.db') + engine = create_engine(f'sqlite:///{db}') + + # Create the table explicitly + metadata = MetaData() + Table('JapanHotels', metadata, + Column('id', Integer, primary_key=True), + Column('Region', String), + Column('Prefecture', String), + Column('hotel_name', String), + Column('price', Integer), + Column('date', Date), + Column('check_in', String), + Column('check_out', String), + Column('group_adults', Integer), + Column('num_rooms', Integer), + Column('group_children', Integer), + Column('selected_currency', String), + Column('scrape_only_hotel', Boolean), + Column('country', String), + Column('city', String) + ) + metadata.create_all(engine) - scraper = JapanScraper(sqlite_name=db, country='Japan', city='', check_in='', check_out='') + scraper = JapanScraper( + engine=engine, + country='Japan', + city='', + check_in='', + check_out='', + group_adults=1, + num_rooms=1, + group_children=0, + selected_currency='USD', + scrape_only_hotel=True + ) scraper.japan_regions = {"Hokkaido": ["Hokkaido"]} current_month = datetime.datetime.now().month scraper.start_month = current_month scraper.end_month = current_month - asyncio.run(scraper.scrape_japan_hotels()) - with sqlite3.connect(db) as conn: - res = conn.execute('SELECT * FROM JapanHotels').fetchall() - assert len(res) > 1 + # Create sample data + sample_data = pd.DataFrame({ + 'Region': ['Hokkaido', 'Hokkaido'], + 'Prefecture': ['Hokkaido', 'Hokkaido'], + 'hotel_name': ['Hotel A', 'Hotel B'], + 'price': [100, 200], + 'date': [datetime.date(2023, current_month, 1), datetime.date(2023, current_month, 2)], + 'check_in': ['2023-11-01', '2023-11-02'], + 'check_out': ['2023-11-02', '2023-11-03'], + 'group_adults': [1, 1], + 'num_rooms': [1, 1], + 'group_children': [0, 0], + 'selected_currency': ['USD', 'USD'], + 'scrape_only_hotel': [True, True], + 'country': ['Japan', 'Japan'], + 'city': ['Hokkaido', 'Hokkaido'] + }) + + # Insert sample data directly into the database + sample_data.to_sql('JapanHotels', engine, if_exists='append', index=False) + + # Mock the _scrape_whole_year method to return our sample data + async def mock_scrape_whole_year(): + return sample_data + + with patch.object(JapanScraper, '_scrape_whole_year', + new=AsyncMock(side_effect=mock_scrape_whole_year)) as mock_scrape: + await scraper.scrape_japan_hotels() + + # Check if the mock was called + assert mock_scrape.called, "_scrape_whole_year was not called" + mock_scrape.assert_called_once() + + # Verify data in the database + with engine.connect() as conn: + # Check if the table exists + result = conn.execute(text("SELECT name FROM sqlite_master WHERE type='table' AND name='JapanHotels'")) + assert result.fetchone() is not None, "JapanHotels table does not exist" + + # Check the number of rows + result = conn.execute(text("SELECT COUNT(*) FROM JapanHotels")) + count = result.scalar() + assert count == 2, f"Expected 2 rows in the database, but found {count}" + + # Check the content of the rows + result = conn.execute(text("SELECT * FROM JapanHotels")) + rows = result.fetchall() + + for i, row in enumerate(rows): + assert row.Region == 'Hokkaido', f"Row {i}: Region mismatch" + assert row.Prefecture == 'Hokkaido', f"Row {i}: Prefecture mismatch" + assert row.hotel_name == f'Hotel {"A" if i == 0 else "B"}', f"Row {i}: hotel_name mismatch" + assert row.price == (100 if i == 0 else 200), f"Row {i}: price mismatch" + assert row.date == f'2023-11-0{i + 1}', f"Row {i}: date mismatch" + assert row.check_in == f'2023-11-0{i + 1}', f"Row {i}: check_in mismatch" + assert row.check_out == f'2023-11-0{i + 2}', f"Row {i}: check_out mismatch" + assert row.group_adults == 1, f"Row {i}: group_adults mismatch" + assert row.num_rooms == 1, f"Row {i}: num_rooms mismatch" + assert row.group_children == 0, f"Row {i}: group_children mismatch" + assert row.selected_currency == 'USD', f"Row {i}: selected_currency mismatch" + assert row.scrape_only_hotel == 1, f"Row {i}: scrape_only_hotel mismatch" # SQLite stores booleans as 0 or 1 + assert row.country == 'Japan', f"Row {i}: country mismatch" + assert row.city == 'Hokkaido', f"Row {i}: city mismatch" + + print("Test completed successfully!") \ No newline at end of file diff --git a/tests/test_main_argparse/test_parse_arguments_main.py b/tests/test_main_argparse/test_parse_arguments_main.py index 1d840d2..41fe808 100644 --- a/tests/test_main_argparse/test_parse_arguments_main.py +++ b/tests/test_main_argparse/test_parse_arguments_main.py @@ -17,7 +17,6 @@ def test_parse_arguments(monkeypatch): "--group_children", "0", "--selected_currency", "USD", "--scrape_only_hotel", - "--sqlite_name", "test.db" ] monkeypatch.setattr(sys, 'argv', test_args) @@ -33,7 +32,6 @@ def test_parse_arguments(monkeypatch): assert args.group_children == 0 assert args.selected_currency == "USD" assert args.scrape_only_hotel is True - assert args.sqlite_name == "test.db" def test_missing_required_arguments(monkeypatch): @@ -49,7 +47,6 @@ def test_missing_required_arguments(monkeypatch): "--group_children", "0", "--selected_currency", "USD", "--scrape_only_hotel", - "--sqlite_name", "test.db" ] monkeypatch.setattr(sys, 'argv', test_args) @@ -71,7 +68,6 @@ def test_invalid_argument_types(monkeypatch): "--group_children", "0", "--selected_currency", "USD", "--scrape_only_hotel", - "--sqlite_name", "test.db" ] monkeypatch.setattr(sys, 'argv', test_args) @@ -93,7 +89,6 @@ def test_boundary_values(monkeypatch): "--group_children", "0", "--selected_currency", "USD", "--scrape_only_hotel", - "--sqlite_name", "test.db" ] monkeypatch.setattr(sys, 'argv', test_args) @@ -115,7 +110,6 @@ def test_valid_japan_scraper_arguments(monkeypatch): "--group_children", "0", "--selected_currency", "USD", "--scrape_only_hotel", - "--sqlite_name", "test.db" ] monkeypatch.setattr(sys, 'argv', test_args) @@ -131,7 +125,6 @@ def test_valid_japan_scraper_arguments(monkeypatch): assert args.group_children == 0 assert args.selected_currency == "USD" assert args.scrape_only_hotel is True - assert args.sqlite_name == "test.db" def test_valid_whole_month_scraper_arguments(monkeypatch): @@ -148,7 +141,6 @@ def test_valid_whole_month_scraper_arguments(monkeypatch): "--group_children", "0", "--selected_currency", "USD", "--scrape_only_hotel", - "--sqlite_name", "test.db", "--year", "2024", "--month", "1" ] @@ -166,7 +158,6 @@ def test_valid_whole_month_scraper_arguments(monkeypatch): assert args.group_children == 0 assert args.selected_currency == "USD" assert args.scrape_only_hotel is True - assert args.sqlite_name == "test.db" assert args.year == 2024 assert args.month == 1 @@ -186,7 +177,6 @@ def test_japan_arguments(monkeypatch): "--group_children", "0", "--selected_currency", "USD", "--scrape_only_hotel", - "--sqlite_name", "test.db" ] monkeypatch.setattr(sys, 'argv', test_args) @@ -208,5 +198,4 @@ def test_japan_arguments(monkeypatch): assert args.num_rooms == 1 assert args.group_children == 0 assert args.selected_currency == "USD" - assert args.scrape_only_hotel is True - assert args.sqlite_name == "test.db" + assert args.scrape_only_hotel is True \ No newline at end of file diff --git a/tests/test_migrate_to_database/test_migrate_data_to_database.py b/tests/test_migrate_to_database/test_migrate_data_to_database.py new file mode 100644 index 0000000..6512160 --- /dev/null +++ b/tests/test_migrate_to_database/test_migrate_data_to_database.py @@ -0,0 +1,74 @@ +from unittest.mock import patch + +import pandas as pd +import pytest +from sqlalchemy import create_engine, inspect +from sqlalchemy.orm import sessionmaker + +from japan_avg_hotel_price_finder.sql.save_to_db import migrate_data_to_database +from japan_avg_hotel_price_finder.sql.db_model import Base, HotelPrice + + +@pytest.fixture +def sqlite_engine(tmp_path): + db = tmp_path / 'test_successful_connection_to_sqlite.db' + engine = create_engine(f'sqlite:///{db}') + Base.metadata.create_all(engine) + return engine + + +@pytest.fixture +def db_session(sqlite_engine): + Session = sessionmaker(bind=sqlite_engine) + return Session() + + +def test_successful_connection_to_sqlite(sqlite_engine, db_session): + # Given + df_filtered = pd.DataFrame({ + 'Hotel': ['Hotel A', 'Hotel B'], + 'Price': [100, 150], + 'Review': [4.5, 3.8], + 'Price/Review': [22.2, 39.5], + 'Location': ['San Francisco', 'San Francisco'], + 'City': ['City X', 'City Y'], + 'Date': ['2022-01-01', '2022-01-02'], + 'AsOf': [pd.Timestamp('2022-01-01'), pd.Timestamp('2022-01-02')] + }) + + # When + migrate_data_to_database(df_filtered, sqlite_engine) + + # Then + inspector = inspect(sqlite_engine) + assert 'HotelPrice' in inspector.get_table_names() + + result = db_session.query(HotelPrice).all() + assert len(result) > 0 + + +@patch('japan_avg_hotel_price_finder.sql.save_to_db.create_avg_hotel_price_by_dow_table') +@patch('japan_avg_hotel_price_finder.sql.save_to_db.create_avg_hotel_room_price_by_date_table') +@patch('japan_avg_hotel_price_finder.sql.save_to_db.create_avg_room_price_by_review_table') +@patch('japan_avg_hotel_price_finder.sql.save_to_db.create_avg_hotel_price_by_month_table') +@patch('japan_avg_hotel_price_finder.sql.save_to_db.create_avg_room_price_by_location') +def test_handle_empty_dataframe(mock_location, mock_month, mock_review, mock_date, mock_dow, sqlite_engine, db_session): + # Given + df_filtered = pd.DataFrame(columns=['Hotel', 'Price', 'Review', 'Location', 'Price/Review', 'City', 'Date', 'AsOf']) + + # When + migrate_data_to_database(df_filtered, sqlite_engine) + + # Then + inspector = inspect(sqlite_engine) + assert 'HotelPrice' in inspector.get_table_names() + + result = db_session.query(HotelPrice).all() + assert len(result) == 0 + + # Assert that all the aggregation functions were called + mock_dow.assert_called_once() + mock_date.assert_called_once() + mock_review.assert_called_once() + mock_month.assert_called_once() + mock_location.assert_called_once() \ No newline at end of file diff --git a/tests/test_migrate_to_sqlite/test_migrate_data_to_sqlite.py b/tests/test_migrate_to_sqlite/test_migrate_data_to_sqlite.py deleted file mode 100644 index 9797b5c..0000000 --- a/tests/test_migrate_to_sqlite/test_migrate_data_to_sqlite.py +++ /dev/null @@ -1,46 +0,0 @@ -import sqlite3 - -import pandas as pd - -from japan_avg_hotel_price_finder.sql.migrate_to_sqlite import migrate_data_to_sqlite - - -def test_successful_connection_to_sqlite(tmp_path): - # Given - df_filtered = pd.DataFrame({ - 'Hotel': ['Hotel A', 'Hotel B'], - 'Price': [100, 150], - 'Review': [4.5, 3.8], - 'Price/Review': [22.2, 39.5], - 'Location': ['San Francisco', 'San Francisco'], - 'City': ['City X', 'City Y'], - 'Date': ['2022-01-01', '2022-01-02'], - 'AsOf': [pd.Timestamp('2022-01-01'), pd.Timestamp('2022-01-02')] - }) - db = tmp_path / 'test_successful_connection_to_sqlite.db' - - # When - migrate_data_to_sqlite(df_filtered, str(db)) - - # Then - with sqlite3.connect(db) as con: - result = con.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='HotelPrice';").fetchone() - assert result is not None - result = con.execute("SELECT * FROM HotelPrice;").fetchall() - assert len(result) > 0 - - -def test_handle_empty_dataframe(tmp_path): - # Given - df_filtered = pd.DataFrame(columns=['Hotel', 'Price', 'Review', 'Location', 'Price/Review', 'City', 'Date', 'AsOf']) - db = tmp_path / 'test_handle_empty_dataframe.db' - - # When - migrate_data_to_sqlite(df_filtered, str(db)) - - # Then - with sqlite3.connect(db) as con: - result = con.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='HotelPrice';").fetchone() - assert result is not None - result = con.execute("SELECT * FROM HotelPrice;").fetchall() - assert len(result) == 0 \ No newline at end of file diff --git a/tests/test_missing_date_checker/test_check_missing_dates.py b/tests/test_missing_date_checker/test_check_missing_dates.py index 4c55f25..efd2b7b 100644 --- a/tests/test_missing_date_checker/test_check_missing_dates.py +++ b/tests/test_missing_date_checker/test_check_missing_dates.py @@ -1,13 +1,15 @@ import datetime import pytest from unittest.mock import MagicMock +from sqlalchemy import create_engine from check_missing_dates import MissingDateChecker @pytest.fixture def missing_date_checker(): - return MissingDateChecker(sqlite_name='test_check_missing_dates.db', city='Tokyo') + engine = create_engine('sqlite:///test_check_missing_dates.db') + return MissingDateChecker(engine=engine, city='Tokyo') def test_check_missing_dates_all_dates_scraped_current_month(missing_date_checker): @@ -16,7 +18,6 @@ def test_check_missing_dates_all_dates_scraped_current_month(missing_date_checke missing_date_list = [] today = datetime.datetime(2023, 10, 1) year = 2023 - missing_date_checker.check_missing_dates(count_of_date_by_mth_asof_today_list, current_month, missing_date_list, today, year) @@ -29,19 +30,18 @@ def test_check_missing_dates_all_dates_scraped_future_month(missing_date_checker missing_date_list = [] today = datetime.datetime(2023, 9, 1) year = 2023 - missing_date_checker.check_missing_dates(count_of_date_by_mth_asof_today_list, current_month, missing_date_list, today, year) assert len(missing_date_list) == 0 + def test_check_missing_dates_all_dates_scraped_past_month(missing_date_checker): count_of_date_by_mth_asof_today_list = [('2023-10', 31)] current_month = '2023-10' missing_date_list = [] today = datetime.datetime(2023, 11, 1) year = 2023 - missing_date_checker.check_missing_dates(count_of_date_by_mth_asof_today_list, current_month, missing_date_list, today, year) @@ -54,7 +54,6 @@ def test_check_missing_dates_some_dates_missing(missing_date_checker): missing_date_list = [] today = datetime.datetime(2023, 10, 1) year = 2023 - missing_date_checker.find_dates_of_the_month_in_db = MagicMock( return_value=({'2023-10-01', '2023-10-02', '2023-10-03'}, '2023-10-31', '2023-10-01') ) @@ -70,8 +69,39 @@ def test_check_missing_dates_no_data(missing_date_checker): missing_date_list = [] today = datetime.datetime(2023, 10, 31) year = 2023 + missing_date_checker.check_missing_dates(count_of_date_by_mth_asof_today_list, current_month, missing_date_list, + today, year) + + assert len(missing_date_list) == 0 + + +def test_check_missing_dates_some_dates_missing_multiple_months(missing_date_checker): + count_of_date_by_mth_asof_today_list = [('2023-10', 28), ('2023-11', 25)] + current_month = '2023-11' + missing_date_list = [] + today = datetime.datetime(2023, 10, 15) + year = 2023 + + def mock_find_dates(days_in_month, month, year): + if month == 10: + return ({'2023-10-01', '2023-10-02', '2023-10-03'}, '2023-10-31', '2023-10-01') + elif month == 11: + return ({'2023-11-01', '2023-11-02', '2023-11-03', '2023-11-04', '2023-11-05'}, '2023-11-30', '2023-11-01') + + missing_date_checker.find_dates_of_the_month_in_db = MagicMock(side_effect=mock_find_dates) missing_date_checker.check_missing_dates(count_of_date_by_mth_asof_today_list, current_month, missing_date_list, today, year) - assert len(missing_date_list) == 0 \ No newline at end of file + assert len(missing_date_list) > 0 + assert any(date.startswith('2023-10-') for date in missing_date_list) + assert any(date.startswith('2023-11-') for date in missing_date_list) + + october_missing = [date for date in missing_date_list if date.startswith('2023-10-')] + november_missing = [date for date in missing_date_list if date.startswith('2023-11-')] + + assert len(october_missing) == 17 # From 2023-10-15 to 2023-10-31 + assert set(october_missing) == set([f'2023-10-{i:02d}' for i in range(15, 32)]) + + assert len(november_missing) == 25 # From 2023-11-06 to 2023-11-30 + assert set(november_missing) == set([f'2023-11-{i:02d}' for i in range(6, 31)]) \ No newline at end of file diff --git a/tests/test_missing_date_checker/test_find_missing_dates_in_db.py b/tests/test_missing_date_checker/test_find_missing_dates_in_db.py index 09465ea..66d5024 100644 --- a/tests/test_missing_date_checker/test_find_missing_dates_in_db.py +++ b/tests/test_missing_date_checker/test_find_missing_dates_in_db.py @@ -2,7 +2,8 @@ from unittest.mock import MagicMock, patch import pytest -from sqlalchemy.orm import Session +from sqlalchemy import create_engine +from sqlalchemy.orm import sessionmaker, Session from check_missing_dates import MissingDateChecker @@ -14,7 +15,9 @@ def mock_today(): @pytest.fixture def missing_date_checker(mock_today): - return MissingDateChecker(sqlite_name='test.db', city='TestCity') + engine = create_engine('sqlite:///test.db') + Session = sessionmaker(bind=engine) + return MissingDateChecker(engine=engine, city='TestCity') @pytest.fixture @@ -97,7 +100,6 @@ def side_effect(*args, **kwargs): result = missing_date_checker.find_missing_dates_in_db(mock_today.year) assert len(result) == 17 # 6 missing in December + 11 missing in January - # Check if all dates are either in the current month/year, next month/year, or the month after assert all( date.startswith(f"{mock_today.year}-{mock_today.month:02d}-") or @@ -140,7 +142,4 @@ def test_find_missing_dates_in_db_special_dates(missing_date_checker, mock_sessi with patch('check_missing_dates.MissingDateChecker.check_missing_dates') as mock_check: mock_check.side_effect = lambda *args, **kwargs: args[2].append(mock_today.strftime('%Y-%m-%d')) - result = missing_date_checker.find_missing_dates_in_db(mock_today.year) - - assert len(result) == 1 - assert result[0] == mock_today.strftime('%Y-%m-%d') + result = missing_date_checker.find_missing_dates_in_db \ No newline at end of file diff --git a/tests/test_missing_date_checker/test_parse_arguments_missing_date_checker.py b/tests/test_missing_date_checker/test_parse_arguments_missing_date_checker.py index c94b3e9..c73bc83 100644 --- a/tests/test_missing_date_checker/test_parse_arguments_missing_date_checker.py +++ b/tests/test_missing_date_checker/test_parse_arguments_missing_date_checker.py @@ -12,7 +12,6 @@ def test_valid_arguments(monkeypatch): "--group_children", "0", "--selected_currency", "USD", "--scrape_only_hotel", "True", - "--sqlite_name", "test.db", "--year", "2024" ] monkeypatch.setattr(sys, 'argv', test_args) @@ -24,7 +23,6 @@ def test_valid_arguments(monkeypatch): assert args.group_children == 0 assert args.selected_currency == "USD" assert args.scrape_only_hotel is True - assert args.sqlite_name == "test.db" assert args.year == 2024 @@ -36,7 +34,6 @@ def test_missing_required_arguments(monkeypatch): "--group_children", "0", "--selected_currency", "USD", "--scrape_only_hotel", "True", - "--sqlite_name", "test.db", "--year", "2024" ] monkeypatch.setattr(sys, 'argv', test_args) @@ -53,7 +50,6 @@ def test_invalid_argument_types(monkeypatch): "--group_children", "0", "--selected_currency", "USD", "--scrape_only_hotel", "True", - "--sqlite_name", "test.db", "--year", "2024" ] monkeypatch.setattr(sys, 'argv', test_args) @@ -70,7 +66,6 @@ def test_boundary_values(monkeypatch): "--group_children", "0", "--selected_currency", "USD", "--scrape_only_hotel", "True", - "--sqlite_name", "test.db", "--year", "2024" ] monkeypatch.setattr(sys, 'argv', test_args) diff --git a/tests/test_missing_date_checker/test_scrape_missing_dates.py b/tests/test_missing_date_checker/test_scrape_missing_dates.py index db1d255..be9c1a4 100644 --- a/tests/test_missing_date_checker/test_scrape_missing_dates.py +++ b/tests/test_missing_date_checker/test_scrape_missing_dates.py @@ -1,11 +1,12 @@ import datetime import pytest -from sqlalchemy import create_engine, func +from sqlalchemy import func, create_engine from sqlalchemy.orm import sessionmaker -from check_missing_dates import scrape_missing_dates, BookingDetails -from japan_avg_hotel_price_finder.sql.db_model import Base, HotelPrice +from check_missing_dates import scrape_missing_dates +from japan_avg_hotel_price_finder.booking_details import BookingDetails +from japan_avg_hotel_price_finder.sql.db_model import HotelPrice, Base @pytest.mark.asyncio @@ -17,7 +18,7 @@ async def test_scrape_missing_dates(tmp_path) -> None: Session = sessionmaker(bind=engine) booking_details_param = BookingDetails(city='Osaka', group_adults=1, num_rooms=1, group_children=0, - selected_currency='USD', scrape_only_hotel=True, sqlite_name=str(db_file)) + selected_currency='USD', scrape_only_hotel=True) today = datetime.datetime.today() if today.month == 12: @@ -34,27 +35,29 @@ async def test_scrape_missing_dates(tmp_path) -> None: third_missing_date = f'{year}-{month_str}-20' missing_dates = [first_missing_date, second_missing_date, third_missing_date] - await scrape_missing_dates(missing_dates_list=missing_dates, booking_details_class=booking_details_param) + await scrape_missing_dates(missing_dates_list=missing_dates, booking_details_class=booking_details_param, + engine=engine) session = Session() try: + # Get the AsOf date from the first record + asof_date = session.query(func.date(HotelPrice.AsOf)).first()[0] + result = ( session.query(func.strftime('%Y-%m', HotelPrice.Date).label('month'), func.count(func.distinct(HotelPrice.Date)).label('count')) .filter(HotelPrice.City == 'Osaka') - .filter(func.date(HotelPrice.AsOf) == func.date('now')) + .filter(func.date(HotelPrice.AsOf) == asof_date) .group_by(func.strftime('%Y-%m', HotelPrice.Date)) .all() ) - assert len(result) == 1 # We expect only one month - assert result[0].count == 3 # We expect 3 dates to be scraped + print(f"Query result: {result}") + + assert len(result) == 1, f"Expected 1 result, but got {len(result)}" + assert result[0].count == 3, f"Expected 3 dates, but got {result[0].count if result else 'no results'}" finally: session.close() # Clean up: drop all tables - Base.metadata.drop_all(engine) - - -if __name__ == '__main__': - pytest.main() \ No newline at end of file + Base.metadata.drop_all(engine) \ No newline at end of file diff --git a/tests/test_sql/test_save_scraped_data.py b/tests/test_sql/test_save_scraped_data.py index 5220159..3135285 100644 --- a/tests/test_sql/test_save_scraped_data.py +++ b/tests/test_sql/test_save_scraped_data.py @@ -1,7 +1,8 @@ -from unittest.mock import patch +from unittest.mock import patch, Mock import pandas as pd import pytest +from sqlalchemy import Engine from japan_avg_hotel_price_finder.sql.save_to_db import save_scraped_data @@ -17,18 +18,22 @@ def sample_dataframe(): def empty_dataframe(): return pd.DataFrame() +@pytest.fixture +def mock_engine(): + return Mock(spec=Engine) + @patch('japan_avg_hotel_price_finder.sql.save_to_db.main_logger') -@patch('japan_avg_hotel_price_finder.sql.save_to_db.migrate_data_to_sqlite') -def test_save_scraped_data_non_empty(mock_migrate, mock_logger, sample_dataframe, tmp_path): - db_path = str(tmp_path / 'test_db.sqlite') - save_scraped_data(sample_dataframe, db_path) - mock_logger.info.assert_called_with(f'Save data to SQLite database: {db_path}') - mock_migrate.assert_called_once_with(sample_dataframe, db_path) +@patch('japan_avg_hotel_price_finder.sql.save_to_db.migrate_data_to_database') +def test_save_scraped_data_non_empty(mock_migrate, mock_logger, sample_dataframe, mock_engine): + save_scraped_data(sample_dataframe, mock_engine) + mock_logger.info.assert_any_call("Saving scraped data...") + mock_logger.info.assert_any_call('Save data to a database') + mock_migrate.assert_called_once_with(sample_dataframe, mock_engine) @patch('japan_avg_hotel_price_finder.sql.save_to_db.main_logger') -@patch('japan_avg_hotel_price_finder.sql.save_to_db.migrate_data_to_sqlite') -def test_save_scraped_data_empty(mock_migrate, mock_logger, empty_dataframe, tmp_path): - db_path = str(tmp_path / 'test_db.sqlite') - save_scraped_data(empty_dataframe, db_path) +@patch('japan_avg_hotel_price_finder.sql.save_to_db.migrate_data_to_database') +def test_save_scraped_data_empty(mock_migrate, mock_logger, empty_dataframe, mock_engine): + save_scraped_data(empty_dataframe, mock_engine) + mock_logger.info.assert_called_with("Saving scraped data...") mock_logger.warning.assert_called_with('The dataframe is empty. No data to save') mock_migrate.assert_not_called() \ No newline at end of file diff --git a/tests/test_whole_mth_scraper/test_whole_month_graphql_scraper.py b/tests/test_whole_mth_scraper/test_whole_month_graphql_scraper.py index 42c6929..ee012cc 100644 --- a/tests/test_whole_mth_scraper/test_whole_month_graphql_scraper.py +++ b/tests/test_whole_mth_scraper/test_whole_month_graphql_scraper.py @@ -12,7 +12,7 @@ async def test_whole_month_graphql_scraper(): current_year = today.year scraper = WholeMonthGraphQLScraper(month=current_mth, city='Osaka', year=current_year, start_day=1, nights=1, - num_rooms=1, group_adults=1, group_children=0, sqlite_name='', + num_rooms=1, group_adults=1, group_children=0, selected_currency='USD', scrape_only_hotel=True, check_in='', check_out='', country='Japan') df = await scraper.scrape_whole_month() @@ -29,7 +29,7 @@ async def test_whole_month_graphql_scraper_past_date(): current_year = today.year scraper = WholeMonthGraphQLScraper(month=past_mth, city='Osaka', year=current_year, start_day=1, nights=1, - num_rooms=1, group_adults=1, group_children=0, sqlite_name='', + num_rooms=1, group_adults=1, group_children=0, selected_currency='USD', scrape_only_hotel=True, check_in='', check_out='', country='Japan') df = await scraper.scrape_whole_month()