Skip to content

Commit

Permalink
Merge branch 'main' into avoid-loading-weights-before-recipe-application
Browse files Browse the repository at this point in the history
  • Loading branch information
rahul-tuli authored Apr 8, 2024
2 parents 6323062 + d636d35 commit f88904c
Show file tree
Hide file tree
Showing 4 changed files with 174 additions and 3 deletions.
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
"GPUtil>=1.4.0",
"protobuf>=3.12.2,<=3.20.3",
"click>=7.1.2,!=8.0.0", # latest version < 8.0 + blocked version with reported bug
"clearml==1.14.4",
]
_nm_deps = [f"{'sparsezoo' if is_release else 'sparsezoo-nightly'}~={version_nm_deps}"]
_deepsparse_deps = [
Expand Down
105 changes: 105 additions & 0 deletions src/sparseml/pytorch/utils/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,21 @@
wandb = None
wandb_err = err


try:
from clearml import Task

clearml_err = None
except Exception as err:
clearml = None
clearml_err = err

from sparseml.utils import ALL_TOKEN, create_dirs


__all__ = [
"BaseLogger",
"ClearMLLogger",
"LambdaLogger",
"PythonLogger",
"TensorBoardLogger",
Expand Down Expand Up @@ -628,6 +638,101 @@ def save(
return True


class ClearMLLogger(LambdaLogger):
@staticmethod
def available() -> bool:
"""
:return: True if wandb is available and installed, False, otherwise
"""
return not clearml_err

def __init__(
self,
name: str = "clearml",
enabled: bool = True,
project_name: str = "sparseml",
task_name: str = "",
):
if task_name == "":
now = datetime.now()
task_name = now.strftime("%d-%m-%Y_%H.%M.%S")

self.task = Task.init(project_name=project_name, task_name=task_name)

super().__init__(
lambda_func=self.log_scalar,
name=name,
enabled=enabled,
)

def log_hyperparams(
self,
params: Dict,
level: Optional[int] = None,
) -> bool:
"""
:param params: Each key-value pair in the dictionary is the name of the
hyper parameter and it's corresponding value.
:return: True if logged, False otherwise.
"""
if not self.enabled:
return False

self.task.connect(params)
return True

def log_scalar(
self,
tag: str,
value: float,
step: Optional[int] = None,
wall_time: Optional[float] = None,
level: Optional[int] = None,
) -> bool:
"""
:param tag: identifying tag to log the value with
:param value: value to save
:param step: global step for when the value was taken
:param wall_time: global wall time for when the value was taken,
defaults to time.time()
:param kwargs: additional logging arguments to support Python and custom loggers
:return: True if logged, False otherwise.
"""
logger = self.task.get_logger()
# each series is superimposed on the same plot on title
logger.report_scalar(
title=tag, series=str(level) or tag, value=value, iteration=step
)
return True

def log_scalars(
self,
tag: str,
values: Dict[str, float],
step: Optional[int] = None,
wall_time: Optional[float] = None,
level: Optional[int] = None,
) -> bool:
"""
:param tag: identifying tag to log the values with
:param values: values to save
:param step: global step for when the values were taken
:param wall_time: global wall time for when the values were taken,
defaults to time.time()
:param kwargs: additional logging arguments to support Python and custom loggers
:return: True if logged, False otherwise.
"""
for k, v in values.items():
self.log_scalar(
tag=f"{tag}.{k}",
value=v,
step=step,
wall_time=wall_time,
level=level,
)
return True


class SparsificationGroupLogger(BaseLogger):
"""
Modifier logger that handles outputting values to other supported systems.
Expand Down
8 changes: 5 additions & 3 deletions tests/sparseml/pytorch/utils/test_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import pytest

from sparseml.pytorch.utils import (
ClearMLLogger,
LambdaLogger,
LoggerManager,
PythonLogger,
Expand All @@ -45,6 +46,7 @@
or True
),
*([WANDBLogger()] if WANDBLogger.available() else []),
*([ClearMLLogger()] if ClearMLLogger.available() else []),
SparsificationGroupLogger(
lambda_func=lambda tag, value, values, step, wall_time, level: logging.info(
f"{tag}, {value}, {values}, {step}, {wall_time}, {level}"
Expand Down Expand Up @@ -79,12 +81,12 @@ def test_log_scalar(self, logger):

def test_log_scalars(self, logger):
logger.log_scalars("test-scalars-tag", {"scalar1": 0.0, "scalar2": 1.0})
logger.log_scalars("test-scalars-tag", {"scalar1": 0.0, "scalar2": 1.0}, 1)
logger.log_scalars("test-scalars-tag2", {"scalar1": 0.0, "scalar2": 1.0}, 1)
logger.log_scalars(
"test-scalars-tag", {"scalar1": 0.0, "scalar2": 1.0}, 2, time.time() - 1
"test-scalars-tag3", {"scalar1": 0.0, "scalar2": 1.0}, 2, time.time() - 1
)
logger.log_scalars(
"test-scalars-tag",
"test-scalars-tag4",
{"scalar1": 0.0, "scalar2": 1.0},
2,
time.time() - 1,
Expand Down
63 changes: 63 additions & 0 deletions tests/sparseml/test_clear_ml.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from pathlib import Path

from clearml import Task
from sparseml.transformers import apply
from sparseml.utils import is_package_available


is_torch_available = is_package_available("torch")
if is_torch_available:
import torch

torch_err = None
else:
torch = object
torch_err = ModuleNotFoundError(
"`torch` is not installed, use `pip install torch` to log to Weights and Biases"
)


def test_oneshot_and_finetune(tmp_path: Path):
recipe_str = "tests/sparseml/transformers/finetune/test_alternate_recipe.yaml"
model = "Xenova/llama2.c-stories15M"
device = "cuda:0"
if is_torch_available and not torch.cuda.is_available():
device = "cpu"
dataset = "wikitext"
dataset_config_name = "wikitext-2-raw-v1"
concatenate_data = True
run_stages = True
output_dir = tmp_path
max_steps = 50
splits = {"train": "train[:50%]", "calibration": "train[50%:60%]"}

# clearML will automatically log default capturing entries without
# explicitly calling logger. Logs accessible in https://app.clear.ml/
Task.init(project_name="test", task_name="test_oneshot_and_finetune")

apply(
model=model,
dataset=dataset,
dataset_config_name=dataset_config_name,
run_stages=run_stages,
output_dir=output_dir,
recipe=recipe_str,
max_steps=max_steps,
concatenate_data=concatenate_data,
splits=splits,
oneshot_device=device,
)

0 comments on commit f88904c

Please sign in to comment.