From 4c94f2f5d27cea130937ed2811f532a2909a303c Mon Sep 17 00:00:00 2001 From: Daniel Silva Date: Sat, 6 Jan 2024 13:14:07 +0000 Subject: [PATCH] feat: limit list of IDs size to 500 in get_papers() --- semanticscholar/AsyncSemanticScholar.py | 6 +++++- semanticscholar/SemanticScholar.py | 2 +- tests/test_semanticscholar.py | 18 ++++++++++++++++++ 3 files changed, 24 insertions(+), 2 deletions(-) diff --git a/semanticscholar/AsyncSemanticScholar.py b/semanticscholar/AsyncSemanticScholar.py index f98ad0d..4b06863 100644 --- a/semanticscholar/AsyncSemanticScholar.py +++ b/semanticscholar/AsyncSemanticScholar.py @@ -128,7 +128,7 @@ async def get_papers( :calls: `POST /paper/batch `_ - :param str paper_ids: list of IDs (must be <= 1000) - S2PaperId,\ + :param str paper_ids: list of IDs (must be <= 500) - S2PaperId,\ CorpusId, DOI, ArXivId, MAG, ACL, PMID, PMCID, or URL from: - semanticscholar.org @@ -143,6 +143,10 @@ async def get_papers( :raises: BadQueryParametersException: if no paper was found. ''' + if len(paper_ids) > 500 or len(paper_ids) == 0: + raise ValueError( + 'The paper_ids parameter must be a list of 1 to 500 IDs.') + if not fields: fields = Paper.SEARCH_FIELDS diff --git a/semanticscholar/SemanticScholar.py b/semanticscholar/SemanticScholar.py index 48a9f26..fb42e07 100644 --- a/semanticscholar/SemanticScholar.py +++ b/semanticscholar/SemanticScholar.py @@ -112,7 +112,7 @@ def get_papers( :calls: `POST /paper/batch `_ - :param str paper_ids: list of IDs (must be <= 1000) - S2PaperId,\ + :param str paper_ids: list of IDs (must be <= 500) - S2PaperId,\ CorpusId, DOI, ArXivId, MAG, ACL, PMID, PMCID, or URL from: - semanticscholar.org diff --git a/tests/test_semanticscholar.py b/tests/test_semanticscholar.py index d41c340..5500259 100644 --- a/tests/test_semanticscholar.py +++ b/tests/test_semanticscholar.py @@ -181,6 +181,14 @@ def test_get_papers(self): self.assertIn( 'E. Duflo', [author.name for author in item.authors]) + def test_get_papers_list_size_exceeded(self): + list_of_paper_ids = [str(i) for i in range(501)] + self.assertRaises(ValueError, self.sch.get_papers, list_of_paper_ids) + + def test_get_papers_list_empty(self): + list_of_paper_ids = [] + self.assertRaises(ValueError, self.sch.get_papers, list_of_paper_ids) + @test_vcr.use_cassette def test_get_paper_authors(self): data = self.sch.get_paper_authors('10.2139/ssrn.2250500') @@ -498,6 +506,16 @@ async def test_get_papers_async(self): with self.subTest(subtest=item.paperId): self.assertIn( 'E. Duflo', [author.name for author in item.authors]) + + async def test_get_papers_list_size_exceeded_async(self): + list_of_paper_ids = [str(i) for i in range(501)] + with self.assertRaises(ValueError): + await self.sch.get_papers(list_of_paper_ids) + + async def test_get_papers_list_empty_async(self): + list_of_paper_ids = [] + with self.assertRaises(ValueError): + await self.sch.get_papers(list_of_paper_ids) @test_vcr.use_cassette async def test_get_paper_authors_async(self):