Skip to content

Commit

Permalink
Split classes and add reading the current validator from environment …
Browse files Browse the repository at this point in the history
…variable
  • Loading branch information
goanpeca committed Sep 24, 2020
1 parent 2b339e7 commit 7b4274c
Show file tree
Hide file tree
Showing 7 changed files with 242 additions and 164 deletions.
2 changes: 0 additions & 2 deletions .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,6 @@ install:
- pip freeze
script:
- py.test -v --cov nbformat nbformat
- pip uninstall fastjsonschema --yes
- py.test -v --cov nbformat nbformat
after_success:
- codecov
matrix:
Expand Down
2 changes: 0 additions & 2 deletions appveyor.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,6 @@ build: off
# to run your custom scripts instead of automatic tests
test_script:
- 'py.test -v --cov nbformat nbformat'
- 'pip uninstall fastjsonschema --yes'
- 'py.test -v --cov nbformat nbformat'

on_success:
- codecov
4 changes: 2 additions & 2 deletions nbformat/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def reads(s, as_version, **kwargs):
if as_version is not NO_CONVERT:
nb = convert(nb, as_version)
try:
validate(nb, use_fast=True)
validate(nb)
except ValidationError as e:
get_logger().error("Notebook JSON is invalid: %s", e)
return nb
Expand Down Expand Up @@ -104,7 +104,7 @@ def writes(nb, version=NO_CONVERT, **kwargs):
else:
version, _ = reader.get_version(nb)
try:
validate(nb, use_fast=True)
validate(nb)
except ValidationError as e:
get_logger().error("Notebook JSON is invalid: %s", e)
return versions[version].writes_json(nb, **kwargs)
Expand Down
71 changes: 49 additions & 22 deletions nbformat/json_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@
libraries.
"""

import os

import jsonschema
from jsonschema import Draft4Validator as _JsonSchemaValidator
from jsonschema import ValidationError

Expand All @@ -16,32 +19,56 @@
_JsonSchemaException = ValidationError


class Validator:
"""
Common validator wrapper to provide a uniform usage of other schema validation
libraries.
"""
class JsonSchemaValidator:
name = "jsonschema"

def __init__(self, schema):
self._schema = schema

# Validation libraries
self._jsonschema = _JsonSchemaValidator(schema) # Default
self._fastjsonschema_validate = fastjsonschema.compile(schema) if fastjsonschema else None
self._default_validator = _JsonSchemaValidator(schema) # Default
self._validator = self._default_validator

def validate(self, data):
"""
Validate the schema of ``data``.
Will use ``fastjsonschema`` if available.
"""
if fastjsonschema:
try:
self._fastjsonschema_validate(data)
except _JsonSchemaException as e:
raise ValidationError(e.message)
else:
self._jsonschema.validate(data)
self._default_validator.validate(data)

def iter_errors(self, data, schema=None):
return self._jsonschema.iter_errors(data, schema)
return self._default_validator.iter_errors(data, schema)


class FastJsonSchemaValidator(JsonSchemaValidator):
name = "fastjsonschema"

def __init__(self, schema):
super().__init__(schema)

self._validator = fastjsonschema.compile(schema)

def validate(self, data):
try:
self._validator(data)
except _JsonSchemaException as error:
raise ValidationError(error.message, schema_path=error.path)


_VALIDATOR_MAP = [
("fastjsonschema", fastjsonschema, FastJsonSchemaValidator),
("jsonschema", jsonschema, JsonSchemaValidator),
]
VALIDATORS = [item[0] for item in _VALIDATOR_MAP]


def _validator_for_name(validator_name):
if validator_name not in VALIDATORS:
raise ValueError("Invalid validator '{0}' value!\nValid values are: {1}".format(
validator_name, VALIDATORS))

for (name, module, validator_cls) in _VALIDATOR_MAP:
if module and validator_name == name:
return validator_cls


def get_current_validator():
"""
Return the current validator based on the value of an environment variable.
"""
validator_name = os.environ.get("NBFORMAT_VALIDATOR", "jsonschema")
return _validator_for_name(validator_name)
9 changes: 5 additions & 4 deletions nbformat/tests/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,10 @@
class TestsBase(unittest.TestCase):
"""Base tests class."""

def fopen(self, f, mode=u'r',encoding='utf-8'):
return io.open(os.path.join(self._get_files_path(), f), mode, encoding=encoding)
@classmethod
def fopen(cls, f, mode=u'r',encoding='utf-8'):
return io.open(os.path.join(cls._get_files_path(), f), mode, encoding=encoding)


def _get_files_path(self):
@classmethod
def _get_files_path(cls):
return os.path.dirname(__file__)
Loading

0 comments on commit 7b4274c

Please sign in to comment.