From 0cc3d9443629e1603d86be59904ae45bca5dee0d Mon Sep 17 00:00:00 2001 From: Joel Mackenzie Date: Fri, 11 Feb 2022 14:41:38 +1100 Subject: [PATCH 01/35] Set up basis for JASS searcher --- pyserini/search/jass/__main__.py | 278 ++++++++++++++++++++++++++++++ pyserini/search/jass/_searcher.py | 170 ++++++++++++++++++ 2 files changed, 448 insertions(+) create mode 100644 pyserini/search/jass/__main__.py create mode 100644 pyserini/search/jass/_searcher.py diff --git a/pyserini/search/jass/__main__.py b/pyserini/search/jass/__main__.py new file mode 100644 index 000000000..e9bf79fdf --- /dev/null +++ b/pyserini/search/jass/__main__.py @@ -0,0 +1,278 @@ +# +# Pyserini: Reproducible IR research with sparse and dense representations +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import argparse +import os + +from tqdm import tqdm +from transformers import AutoTokenizer + +from pyserini.analysis import JDefaultEnglishAnalyzer, JWhiteSpaceAnalyzer +from pyserini.output_writer import OutputFormat, get_output_writer +from pyserini.pyclass import autoclass +from pyserini.query_iterator import get_query_iterator, TopicsFormat +from pyserini.search import ImpactSearcher, SimpleSearcher, JDisjunctionMaxQueryGenerator +from pyserini.search.lucene.reranker import ClassifierType, PseudoRelevanceClassifierReranker + + +def set_bm25_parameters(searcher, index, k1=None, b=None): + if k1 is not None or b is not None: + if k1 is None or b is None: + print('Must set *both* k1 and b for BM25!') + exit() + print(f'Setting BM25 parameters: k1={k1}, b={b}') + searcher.set_bm25(k1, b) + else: + # Automatically set bm25 parameters based on known index: + if index == 'msmarco-passage' or index == 'msmarco-passage-slim': + print('MS MARCO passage: setting k1=0.82, b=0.68') + searcher.set_bm25(0.82, 0.68) + elif index == 'msmarco-passage-expanded': + print('MS MARCO passage w/ doc2query-T5 expansion: setting k1=2.18, b=0.86') + searcher.set_bm25(2.18, 0.86) + elif index == 'msmarco-doc' or index == 'msmarco-doc-slim': + print('MS MARCO doc: setting k1=4.46, b=0.82') + searcher.set_bm25(4.46, 0.82) + elif index == 'msmarco-doc-per-passage' or index == 'msmarco-doc-per-passage-slim': + print('MS MARCO doc, per passage: setting k1=2.16, b=0.61') + searcher.set_bm25(2.16, 0.61) + elif index == 'msmarco-doc-expanded-per-doc': + print('MS MARCO doc w/ doc2query-T5 (per doc) expansion: setting k1=4.68, b=0.87') + searcher.set_bm25(4.68, 0.87) + elif index == 'msmarco-doc-expanded-per-passage': + print('MS MARCO doc w/ doc2query-T5 (per passage) expansion: setting k1=2.56, b=0.59') + searcher.set_bm25(2.56, 0.59) + + +def define_search_args(parser): + parser.add_argument('--index', type=str, metavar='path to index or index name', required=True, + help="Path to Lucene index or name of prebuilt index.") + + parser.add_argument('--impact', action='store_true', help="Use Impact.") + parser.add_argument('--encoder', type=str, default=None, help="encoder name") + parser.add_argument('--min-idf', type=int, default=0, help="minimum idf") + + parser.add_argument('--bm25', action='store_true', default=True, help="Use BM25 (default).") + parser.add_argument('--k1', type=float, help='BM25 k1 parameter.') + parser.add_argument('--b', type=float, help='BM25 b parameter.') + + parser.add_argument('--rm3', action='store_true', help="Use RM3") + parser.add_argument('--qld', action='store_true', help="Use QLD") + + parser.add_argument('--language', type=str, help='language code for BM25, e.g. zh for Chinese', default='en') + + parser.add_argument('--prcl', type=ClassifierType, nargs='+', default=[], + help='Specify the classifier PseudoRelevanceClassifierReranker uses.') + parser.add_argument('--prcl.vectorizer', dest='vectorizer', type=str, + help='Type of vectorizer. Available: TfidfVectorizer, BM25Vectorizer.') + parser.add_argument('--prcl.r', dest='r', type=int, default=10, + help='Number of positive labels in pseudo relevance feedback.') + parser.add_argument('--prcl.n', dest='n', type=int, default=100, + help='Number of negative labels in pseudo relevance feedback.') + parser.add_argument('--prcl.alpha', dest='alpha', type=float, default=0.5, + help='Alpha value for interpolation in pseudo relevance feedback.') + + parser.add_argument('--fields', metavar="key=value", nargs='+', + help='Fields to search with assigned float weights.') + parser.add_argument('--dismax', action='store_true', default=False, + help='Use disjunction max queries when searching multiple fields.') + parser.add_argument('--dismax.tiebreaker', dest='tiebreaker', type=float, default=0.0, + help='The tiebreaker weight to use in disjunction max queries.') + + parser.add_argument('--stopwords', type=str, help='Path to file with customstopwords.') + + +if __name__ == "__main__": + JSimpleSearcher = autoclass('io.anserini.search.SimpleSearcher') + parser = argparse.ArgumentParser(description='Search a Lucene index.') + define_search_args(parser) + parser.add_argument('--topics', type=str, metavar='topic_name', required=True, + help="Name of topics. Available: robust04, robust05, core17, core18.") + parser.add_argument('--hits', type=int, metavar='num', + required=False, default=1000, help="Number of hits.") + parser.add_argument('--topics-format', type=str, metavar='format', default=TopicsFormat.DEFAULT.value, + help=f"Format of topics. Available: {[x.value for x in list(TopicsFormat)]}") + parser.add_argument('--output-format', type=str, metavar='format', default=OutputFormat.TREC.value, + help=f"Format of output. Available: {[x.value for x in list(OutputFormat)]}") + parser.add_argument('--output', type=str, metavar='path', + help="Path to output file.") + parser.add_argument('--max-passage', action='store_true', + default=False, help="Select only max passage from document.") + parser.add_argument('--max-passage-hits', type=int, metavar='num', required=False, default=100, + help="Final number of hits when selecting only max passage.") + parser.add_argument('--max-passage-delimiter', type=str, metavar='str', required=False, default='#', + help="Delimiter between docid and passage id.") + parser.add_argument('--batch-size', type=int, metavar='num', required=False, + default=1, help="Specify batch size to search the collection concurrently.") + parser.add_argument('--threads', type=int, metavar='num', required=False, + default=1, help="Maximum number of threads to use.") + parser.add_argument('--tokenizer', type=str, help='tokenizer used to preprocess topics') + parser.add_argument('--remove-duplicates', action='store_true', default=False, help="Remove duplicate docs.") + + args = parser.parse_args() + + query_iterator = get_query_iterator(args.topics, TopicsFormat(args.topics_format)) + topics = query_iterator.topics + + if not args.impact: + if os.path.exists(args.index): + # create searcher from index directory + searcher = SimpleSearcher(args.index) + else: + # create searcher from prebuilt index name + searcher = SimpleSearcher.from_prebuilt_index(args.index) + elif args.impact: + if os.path.exists(args.index): + searcher = ImpactSearcher(args.index, args.encoder, args.min_idf) + else: + searcher = ImpactSearcher.from_prebuilt_index(args.index, args.encoder, args.min_idf) + + if args.language != 'en': + searcher.set_language(args.language) + + if not searcher: + exit() + + search_rankers = [] + + if args.qld: + search_rankers.append('qld') + searcher.set_qld() + elif args.bm25: + search_rankers.append('bm25') + set_bm25_parameters(searcher, args.index, args.k1, args.b) + + if args.rm3: + search_rankers.append('rm3') + searcher.set_rm3() + + fields = dict() + if args.fields: + fields = dict([pair.split('=') for pair in args.fields]) + print(f'Searching over fields: {fields}') + + query_generator = None + if args.dismax: + query_generator = JDisjunctionMaxQueryGenerator(args.tiebreaker) + print(f'Using dismax query generator with tiebreaker={args.tiebreaker}') + + if args.tokenizer != None: + analyzer = JWhiteSpaceAnalyzer() + searcher.set_analyzer(analyzer) + print(f'Using whitespace analyzer because of pretokenized topics') + tokenizer = AutoTokenizer.from_pretrained(args.tokenizer) + print(f'Using {args.tokenizer} to preprocess topics') + + if args.stopwords: + analyzer = JDefaultEnglishAnalyzer.fromArguments('porter', False, args.stopwords) + searcher.set_analyzer(analyzer) + print(f'Using custom stopwords={args.stopwords}') + + # get re-ranker + use_prcl = args.prcl and len(args.prcl) > 0 and args.alpha > 0 + if use_prcl is True: + ranker = PseudoRelevanceClassifierReranker( + searcher.index_dir, args.vectorizer, args.prcl, r=args.r, n=args.n, alpha=args.alpha) + + # build output path + output_path = args.output + if output_path is None: + if use_prcl is True: + clf_rankers = [] + for t in args.prcl: + if t == ClassifierType.LR: + clf_rankers.append('lr') + elif t == ClassifierType.SVM: + clf_rankers.append('svm') + + r_str = f'prcl.r_{args.r}' + n_str = f'prcl.n_{args.n}' + a_str = f'prcl.alpha_{args.alpha}' + clf_str = 'prcl_' + '+'.join(clf_rankers) + tokens1 = ['run', args.topics, '+'.join(search_rankers)] + tokens2 = [args.vectorizer, clf_str, r_str, n_str, a_str] + output_path = '.'.join(tokens1) + '-' + '-'.join(tokens2) + ".txt" + else: + tokens = ['run', args.topics, '+'.join(search_rankers), 'txt'] + output_path = '.'.join(tokens) + + print(f'Running {args.topics} topics, saving to {output_path}...') + tag = output_path[:-4] if args.output is None else 'Anserini' + + output_writer = get_output_writer(output_path, OutputFormat(args.output_format), 'w', + max_hits=args.hits, tag=tag, topics=topics, + use_max_passage=args.max_passage, + max_passage_delimiter=args.max_passage_delimiter, + max_passage_hits=args.max_passage_hits) + + with output_writer: + batch_topics = list() + batch_topic_ids = list() + for index, (topic_id, text) in enumerate(tqdm(query_iterator, total=len(topics.keys()))): + if (args.tokenizer != None): + toks = tokenizer.tokenize(text) + text = ' ' + text = text.join(toks) + if args.batch_size <= 1 and args.threads <= 1: + if args.impact: + hits = searcher.search(text, args.hits, fields=fields) + else: + hits = searcher.search(text, args.hits, query_generator=query_generator, fields=fields) + results = [(topic_id, hits)] + else: + batch_topic_ids.append(str(topic_id)) + batch_topics.append(text) + if (index + 1) % args.batch_size == 0 or \ + index == len(topics.keys()) - 1: + if args.impact: + results = searcher.batch_search( + batch_topics, batch_topic_ids, args.hits, args.threads, fields=fields + ) + else: + results = searcher.batch_search( + batch_topics, batch_topic_ids, args.hits, args.threads, + query_generator=query_generator, fields=fields + ) + results = [(id_, results[id_]) for id_ in batch_topic_ids] + batch_topic_ids.clear() + batch_topics.clear() + else: + continue + + for topic, hits in results: + # do rerank + if use_prcl and len(hits) > (args.r + args.n): + docids = [hit.docid.strip() for hit in hits] + scores = [hit.score for hit in hits] + scores, docids = ranker.rerank(docids, scores) + docid_score_map = dict(zip(docids, scores)) + for hit in hits: + hit.score = docid_score_map[hit.docid.strip()] + + if args.remove_duplicates: + seen_docids = set() + dedup_hits = [] + for hit in hits: + if hit.docid.strip() in seen_docids: + continue + seen_docids.add(hit.docid.strip()) + dedup_hits.append(hit) + hits = dedup_hits + + # write results + output_writer.write(topic, hits) + + results.clear() diff --git a/pyserini/search/jass/_searcher.py b/pyserini/search/jass/_searcher.py new file mode 100644 index 000000000..cb87df18c --- /dev/null +++ b/pyserini/search/jass/_searcher.py @@ -0,0 +1,170 @@ +# +# Pyserini: Reproducible IR research with sparse and dense representations +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +This module provides Pyserini's Python search interface to JASSv2. The main entry point is the ``JASSv2Searcher`` +class, which wraps the C++ ``JASS_anytime_api``. +""" + +import logging +from typing import Dict, List, Optional, Union + +logger = logging.getLogger(__name__) + +# Wrappers around JASS classes + +class JASSv2Searcher: + """Wrapper class for the ``JASS_anytime_api`` in JASSv2. + + Parameters + ---------- + index_dir : str + Path to JASS index directory. + """ + + def __init__(self, index_dir: str): + self.index_dir = index_dir + self.object = None #TODO + + # XXX: TODO: This is the Lucene version for reference... + def search(self, q: Union[str, JQuery], k: int = 10, query_generator: JQueryGenerator = None, + fields=dict(), strip_segment_id=False, remove_dups=False) -> List[JSimpleSearcherResult]: + """Search the collection. + + Parameters + ---------- + q : Union[str, JQuery] + Query string or the ``JQuery`` objected. + k : int + Number of hits to return. + query_generator : JQueryGenerator + Generator to build queries. Set to ``None`` by default to use Anserini default. + fields : dict + Optional map of fields to search with associated boosts. + strip_segment_id : bool + Remove the .XXXXX suffix used to denote different segments from an document. + remove_dups : bool + Remove duplicate docids when writing final run output. + + Returns + ------- + List[JSimpleSearcherResult] + List of search results. + """ + + jfields = JHashMap() + for (field, boost) in fields.items(): + jfields.put(field, JFloat(boost)) + + hits = None + if query_generator: + if not fields: + hits = self.object.search(query_generator, q, k) + else: + hits = self.object.searchFields(query_generator, q, jfields, k) + elif isinstance(q, JQuery): + # Note that RM3 requires the notion of a query (string) to estimate the appropriate models. If we're just + # given a Lucene query, it's unclear what the "query" is for this estimation. One possibility is to extract + # all the query terms from the Lucene query, although this might yield unexpected behavior from the user's + # perspective. Until we think through what exactly is the "right thing to do", we'll raise an exception + # here explicitly. + if self.is_using_rm3(): + raise NotImplementedError('RM3 incompatible with search using a Lucene query.') + if fields: + raise NotImplementedError('Cannot specify fields to search when using a Lucene query.') + hits = self.object.search(q, k) + else: + if not fields: + hits = self.object.search(q, k) + else: + hits = self.object.searchFields(q, jfields, k) + + docids = set() + filtered_hits = [] + + for hit in hits: + if strip_segment_id is True: + hit.docid = hit.docid.split('.')[0] + + if hit.docid in docids: + continue + + filtered_hits.append(hit) + + if remove_dups is True: + docids.add(hit.docid) + + return filtered_hits + + def batch_search(self, queries: List[str], qids: List[str], k: int = 10, threads: int = 1, + query_generator: JQueryGenerator = None, fields = dict()) -> Dict[str, List[JSimpleSearcherResult]]: + """Search the collection concurrently for multiple queries, using multiple threads. + + Parameters + ---------- + queries : List[str] + List of query strings. + qids : List[str] + List of corresponding query ids. + k : int + Number of hits to return. + threads : int + Maximum number of threads to use. + query_generator : JQueryGenerator + Generator to build queries. Set to ``None`` by default to use Anserini default. + fields : dict + Optional map of fields to search with associated boosts. + + Returns + ------- + Dict[str, List[JSimpleSearcherResult]] + Dictionary holding the search results, with the query ids as keys and the corresponding lists of search + results as the values. + """ + query_strings = JArrayList() + qid_strings = JArrayList() + for query in queries: + query_strings.add(query) + + for qid in qids: + qid_strings.add(qid) + + jfields = JHashMap() + for (field, boost) in fields.items(): + jfields.put(field, JFloat(boost)) + + if query_generator: + if not fields: + results = self.object.batchSearch(query_generator, query_strings, qid_strings, int(k), int(threads)) + else: + results = self.object.batchSearchFields(query_generator, query_strings, qid_strings, int(k), int(threads), jfields) + else: + if not fields: + results = self.object.batchSearch(query_strings, qid_strings, int(k), int(threads)) + else: + results = self.object.batchSearchFields(query_strings, qid_strings, int(k), int(threads), jfields) + return {r.getKey(): r.getValue() for r in results.entrySet().toArray()} + + # XXX: TODO: This is the Anserini version but may be useful as reference + def convert_to_search_result(run: TrecRun, docid_to_search_result: Dict[str, JSimpleSearcherResult]) -> List[JSimpleSearcherResult]: + search_results = [] + + for _, _, docid, _, score, _ in run.to_numpy(): + search_result = docid_to_search_result[docid] + search_result.score = score + search_results.append(search_result) + + return search_results From fb4866bec4df76ed530501735133e4b039a4c746 Mon Sep 17 00:00:00 2001 From: Pradeesh Date: Fri, 11 Feb 2022 09:40:11 +0000 Subject: [PATCH 02/35] implemented init method to load index from pyjass --- pyserini/search/jass/_searcher.py | 62 ++++++++++++++++--------------- 1 file changed, 32 insertions(+), 30 deletions(-) diff --git a/pyserini/search/jass/_searcher.py b/pyserini/search/jass/_searcher.py index cb87df18c..bd0bb1ffc 100644 --- a/pyserini/search/jass/_searcher.py +++ b/pyserini/search/jass/_searcher.py @@ -20,6 +20,7 @@ """ import logging +import pyjass from typing import Dict, List, Optional, Union logger = logging.getLogger(__name__) @@ -35,13 +36,17 @@ class JASSv2Searcher: Path to JASS index directory. """ - def __init__(self, index_dir: str): + def __init__(self, index_dir: str, version: int = 2): self.index_dir = index_dir - self.object = None #TODO + self.object = pyjass.anytime() + index = self.object.load_index(version,index_dir) + if index != 0: + raise Exception('Unable to load index - error code' + str(index)) + # XXX: TODO: This is the Lucene version for reference... - def search(self, q: Union[str, JQuery], k: int = 10, query_generator: JQueryGenerator = None, - fields=dict(), strip_segment_id=False, remove_dups=False) -> List[JSimpleSearcherResult]: + def search(self, q: str, k: int = 10, query_generator: JQueryGenerator = None, + fields=dict(), strip_segment_id=False, remove_dups=False) -> List[str]: """Search the collection. Parameters @@ -64,33 +69,30 @@ def search(self, q: Union[str, JQuery], k: int = 10, query_generator: JQueryGene List[JSimpleSearcherResult] List of search results. """ - - jfields = JHashMap() - for (field, boost) in fields.items(): - jfields.put(field, JFloat(boost)) - hits = None - if query_generator: - if not fields: - hits = self.object.search(query_generator, q, k) - else: - hits = self.object.searchFields(query_generator, q, jfields, k) - elif isinstance(q, JQuery): - # Note that RM3 requires the notion of a query (string) to estimate the appropriate models. If we're just - # given a Lucene query, it's unclear what the "query" is for this estimation. One possibility is to extract - # all the query terms from the Lucene query, although this might yield unexpected behavior from the user's - # perspective. Until we think through what exactly is the "right thing to do", we'll raise an exception - # here explicitly. - if self.is_using_rm3(): - raise NotImplementedError('RM3 incompatible with search using a Lucene query.') - if fields: - raise NotImplementedError('Cannot specify fields to search when using a Lucene query.') - hits = self.object.search(q, k) - else: - if not fields: - hits = self.object.search(q, k) - else: - hits = self.object.searchFields(q, jfields, k) + self.object.set_top_k(k) + # if query_generator: + # if not fields: + # hits = self.object.search(query_generator, q, k) + # else: + # hits = self.object.searchFields(query_generator, q, jfields, k) + # elif isinstance(q, JQuery): + # # Note that RM3 requires the notion of a query (string) to estimate the appropriate models. If we're just + # # given a Lucene query, it's unclear what the "query" is for this estimation. One possibility is to extract + # # all the query terms from the Lucene query, although this might yield unexpected behavior from the user's + # # perspective. Until we think through what exactly is the "right thing to do", we'll raise an exception + # # here explicitly. + # if self.is_using_rm3(): + # raise NotImplementedError('RM3 incompatible with search using a Lucene query.') + # if fields: + # raise NotImplementedError('Cannot specify fields to search when using a Lucene query.') + # hits = self.object.search(q, k) + # else: + # if not fields: + # hits = self.object.search(q, k) + + # else: + # hits = self.object.searchFields(q, jfields, k) docids = set() filtered_hits = [] From 04868b12a76a2d80313ad12b676b375aa1230681 Mon Sep 17 00:00:00 2001 From: Pradeesh Date: Sat, 12 Feb 2022 02:36:08 +0000 Subject: [PATCH 03/35] initial commit for search --- pyserini/search/jass/_searcher.py | 73 ++++--------------------------- 1 file changed, 9 insertions(+), 64 deletions(-) diff --git a/pyserini/search/jass/_searcher.py b/pyserini/search/jass/_searcher.py index bd0bb1ffc..c0b94c420 100644 --- a/pyserini/search/jass/_searcher.py +++ b/pyserini/search/jass/_searcher.py @@ -45,74 +45,19 @@ def __init__(self, index_dir: str, version: int = 2): # XXX: TODO: This is the Lucene version for reference... - def search(self, q: str, k: int = 10, query_generator: JQueryGenerator = None, - fields=dict(), strip_segment_id=False, remove_dups=False) -> List[str]: - """Search the collection. - - Parameters - ---------- - q : Union[str, JQuery] - Query string or the ``JQuery`` objected. - k : int - Number of hits to return. - query_generator : JQueryGenerator - Generator to build queries. Set to ``None`` by default to use Anserini default. - fields : dict - Optional map of fields to search with associated boosts. - strip_segment_id : bool - Remove the .XXXXX suffix used to denote different segments from an document. - remove_dups : bool - Remove duplicate docids when writing final run output. - - Returns - ------- - List[JSimpleSearcherResult] - List of search results. - """ + def search(self, q: str, k: int = 10, rho: int = 10, + fields=dict(), strip_segment_id=False, remove_dups=False) -> List[pyjass.JASS_anytime_result]: + hits = None self.object.set_top_k(k) - # if query_generator: - # if not fields: - # hits = self.object.search(query_generator, q, k) - # else: - # hits = self.object.searchFields(query_generator, q, jfields, k) - # elif isinstance(q, JQuery): - # # Note that RM3 requires the notion of a query (string) to estimate the appropriate models. If we're just - # # given a Lucene query, it's unclear what the "query" is for this estimation. One possibility is to extract - # # all the query terms from the Lucene query, although this might yield unexpected behavior from the user's - # # perspective. Until we think through what exactly is the "right thing to do", we'll raise an exception - # # here explicitly. - # if self.is_using_rm3(): - # raise NotImplementedError('RM3 incompatible with search using a Lucene query.') - # if fields: - # raise NotImplementedError('Cannot specify fields to search when using a Lucene query.') - # hits = self.object.search(q, k) - # else: - # if not fields: - # hits = self.object.search(q, k) - - # else: - # hits = self.object.searchFields(q, jfields, k) - - docids = set() - filtered_hits = [] - - for hit in hits: - if strip_segment_id is True: - hit.docid = hit.docid.split('.')[0] - - if hit.docid in docids: - continue - - filtered_hits.append(hit) - - if remove_dups is True: - docids.add(hit.docid) - - return filtered_hits + self.object.set_postings_to_process(rho) + results = self.object.search(q) + + return results.results_list # TO-DO make it pyserini compatible + def batch_search(self, queries: List[str], qids: List[str], k: int = 10, threads: int = 1, - query_generator: JQueryGenerator = None, fields = dict()) -> Dict[str, List[JSimpleSearcherResult]]: + query_generator: JQueryGenerator = None, fields = dict()) -> Dict[str, List[pyjass.JASS_anytime_result]]: """Search the collection concurrently for multiple queries, using multiple threads. Parameters From 5d97414e3f0f33431b3b1ee6ec7810e5156a0272 Mon Sep 17 00:00:00 2001 From: Pradeesh Date: Mon, 14 Feb 2022 01:53:30 +0000 Subject: [PATCH 04/35] implemented crude version of search --- pyserini/search/jass/_searcher.py | 151 +++++++++++++++++------------- 1 file changed, 86 insertions(+), 65 deletions(-) diff --git a/pyserini/search/jass/_searcher.py b/pyserini/search/jass/_searcher.py index c0b94c420..ee31f50f8 100644 --- a/pyserini/search/jass/_searcher.py +++ b/pyserini/search/jass/_searcher.py @@ -22,7 +22,8 @@ import logging import pyjass from typing import Dict, List, Optional, Union - +from pyserini.trectools import TrecRun +from pyserini.dsearch import DenseSearchResult logger = logging.getLogger(__name__) # Wrappers around JASS classes @@ -46,72 +47,92 @@ def __init__(self, index_dir: str, version: int = 2): # XXX: TODO: This is the Lucene version for reference... def search(self, q: str, k: int = 10, rho: int = 10, - fields=dict(), strip_segment_id=False, remove_dups=False) -> List[pyjass.JASS_anytime_result]: + fields=dict(), strip_segment_id=False, remove_dups=False) -> List[DenseSearchResult]: - hits = None + docid_score_pair = list() self.object.set_top_k(k) self.object.set_postings_to_process(rho) results = self.object.search(q) - - return results.results_list # TO-DO make it pyserini compatible - - - def batch_search(self, queries: List[str], qids: List[str], k: int = 10, threads: int = 1, - query_generator: JQueryGenerator = None, fields = dict()) -> Dict[str, List[pyjass.JASS_anytime_result]]: - """Search the collection concurrently for multiple queries, using multiple threads. - - Parameters - ---------- - queries : List[str] - List of query strings. - qids : List[str] - List of corresponding query ids. - k : int - Number of hits to return. - threads : int - Maximum number of threads to use. - query_generator : JQueryGenerator - Generator to build queries. Set to ``None`` by default to use Anserini default. - fields : dict - Optional map of fields to search with associated boosts. - - Returns - ------- - Dict[str, List[JSimpleSearcherResult]] - Dictionary holding the search results, with the query ids as keys and the corresponding lists of search - results as the values. - """ - query_strings = JArrayList() - qid_strings = JArrayList() + queries = results.results_list.split('\n') for query in queries: - query_strings.add(query) - - for qid in qids: - qid_strings.add(qid) - - jfields = JHashMap() - for (field, boost) in fields.items(): - jfields.put(field, JFloat(boost)) - - if query_generator: - if not fields: - results = self.object.batchSearch(query_generator, query_strings, qid_strings, int(k), int(threads)) - else: - results = self.object.batchSearchFields(query_generator, query_strings, qid_strings, int(k), int(threads), jfields) - else: - if not fields: - results = self.object.batchSearch(query_strings, qid_strings, int(k), int(threads)) - else: - results = self.object.batchSearchFields(query_strings, qid_strings, int(k), int(threads), jfields) - return {r.getKey(): r.getValue() for r in results.entrySet().toArray()} - - # XXX: TODO: This is the Anserini version but may be useful as reference - def convert_to_search_result(run: TrecRun, docid_to_search_result: Dict[str, JSimpleSearcherResult]) -> List[JSimpleSearcherResult]: - search_results = [] - - for _, _, docid, _, score, _ in run.to_numpy(): - search_result = docid_to_search_result[docid] - search_result.score = score - search_results.append(search_result) - - return search_results + qrel = query.split(' ') # split by space + if len(qrel) == 6: + docid_score_pair.append(DenseSearchResult(qrel[2], float(qrel[4]))) # make it as a dense object so pyserini downstream tasks know how to handle - quick way + + + return docid_score_pair + + + + # def batch_search(self, queries: List[str], qids: List[str], k: int = 10, threads: int = 1, + # query_generator: JQueryGenerator = None, fields = dict()) -> Dict[str, List[pyjass.JASS_anytime_result]]: + # """Search the collection concurrently for multiple queries, using multiple threads. + + # Parameters + # ---------- + # queries : List[str] + # List of query strings. + # qids : List[str] + # List of corresponding query ids. + # k : int + # Number of hits to return. + # threads : int + # Maximum number of threads to use. + # query_generator : JQueryGenerator + # Generator to build queries. Set to ``None`` by default to use Anserini default. + # fields : dict + # Optional map of fields to search with associated boosts. + + # Returns + # ------- + # Dict[str, List[JSimpleSearcherResult]] + # Dictionary holding the search results, with the query ids as keys and the corresponding lists of search + # results as the values. + # """ + # query_strings = JArrayList() + # qid_strings = JArrayList() + # for query in queries: + # query_strings.add(query) + + # for qid in qids: + # qid_strings.add(qid) + + # jfields = JHashMap() + # for (field, boost) in fields.items(): + # jfields.put(field, JFloat(boost)) + + # if query_generator: + # if not fields: + # results = self.object.batchSearch(query_generator, query_strings, qid_strings, int(k), int(threads)) + # else: + # results = self.object.batchSearchFields(query_generator, query_strings, qid_strings, int(k), int(threads), jfields) + # else: + # if not fields: + # results = self.object.batchSearch(query_strings, qid_strings, int(k), int(threads)) + # else: + # results = self.object.batchSearchFields(query_strings, qid_strings, int(k), int(threads), jfields) + # return {r.getKey(): r.getValue() for r in results.entrySet().toArray()} + + # # XXX: TODO: This is the Anserini version but may be useful as reference + # def convert_to_search_result(run: TrecRun, docid_to_search_result: str]) -> List[JSimpleSearcherResult]: + # search_results = [] + + # for _, _, docid, _, score, _ in run.to_numpy(): + # search_result = docid_to_search_result[docid] + # search_result.score = score + # search_results.append(search_result) + + # return search_results + +# Quick and dirty test to load index, search and also get the hits + +def main(): + blah = JASSv2Searcher('/home/pradeesh') # collection to Jass pre-built Index + hits = blah.search('what is a lobster roll') + + for i in range(0, 5): + print(f'{i+1:2} {hits[i].docid:7} {hits[i].score:.5f}') + + +if __name__ == "__main__": + main() \ No newline at end of file From 79a7184641156423801da29fc803b72932500320 Mon Sep 17 00:00:00 2001 From: Pradeesh Date: Mon, 14 Feb 2022 08:45:50 +0000 Subject: [PATCH 05/35] initial work on main class --- pyserini/search/jass/__main__.py | 25 +++---------------------- 1 file changed, 3 insertions(+), 22 deletions(-) diff --git a/pyserini/search/jass/__main__.py b/pyserini/search/jass/__main__.py index e9bf79fdf..6ddb065db 100644 --- a/pyserini/search/jass/__main__.py +++ b/pyserini/search/jass/__main__.py @@ -59,20 +59,10 @@ def set_bm25_parameters(searcher, index, k1=None, b=None): def define_search_args(parser): parser.add_argument('--index', type=str, metavar='path to index or index name', required=True, - help="Path to Lucene index or name of prebuilt index.") - - parser.add_argument('--impact', action='store_true', help="Use Impact.") - parser.add_argument('--encoder', type=str, default=None, help="encoder name") - parser.add_argument('--min-idf', type=int, default=0, help="minimum idf") + help="Path to pyJass index") parser.add_argument('--bm25', action='store_true', default=True, help="Use BM25 (default).") - parser.add_argument('--k1', type=float, help='BM25 k1 parameter.') - parser.add_argument('--b', type=float, help='BM25 b parameter.') - - parser.add_argument('--rm3', action='store_true', help="Use RM3") - parser.add_argument('--qld', action='store_true', help="Use QLD") - - parser.add_argument('--language', type=str, help='language code for BM25, e.g. zh for Chinese', default='en') + parser.add_argument('--rho', type=float, help='rho parameter.') parser.add_argument('--prcl', type=ClassifierType, nargs='+', default=[], help='Specify the classifier PseudoRelevanceClassifierReranker uses.') @@ -85,19 +75,12 @@ def define_search_args(parser): parser.add_argument('--prcl.alpha', dest='alpha', type=float, default=0.5, help='Alpha value for interpolation in pseudo relevance feedback.') - parser.add_argument('--fields', metavar="key=value", nargs='+', - help='Fields to search with assigned float weights.') - parser.add_argument('--dismax', action='store_true', default=False, - help='Use disjunction max queries when searching multiple fields.') - parser.add_argument('--dismax.tiebreaker', dest='tiebreaker', type=float, default=0.0, - help='The tiebreaker weight to use in disjunction max queries.') - parser.add_argument('--stopwords', type=str, help='Path to file with customstopwords.') if __name__ == "__main__": JSimpleSearcher = autoclass('io.anserini.search.SimpleSearcher') - parser = argparse.ArgumentParser(description='Search a Lucene index.') + parser = argparse.ArgumentParser(description='Search a pyJass index.') define_search_args(parser) parser.add_argument('--topics', type=str, metavar='topic_name', required=True, help="Name of topics. Available: robust04, robust05, core17, core18.") @@ -119,8 +102,6 @@ def define_search_args(parser): default=1, help="Specify batch size to search the collection concurrently.") parser.add_argument('--threads', type=int, metavar='num', required=False, default=1, help="Maximum number of threads to use.") - parser.add_argument('--tokenizer', type=str, help='tokenizer used to preprocess topics') - parser.add_argument('--remove-duplicates', action='store_true', default=False, help="Remove duplicate docs.") args = parser.parse_args() From 2d5c1dc4c652023cca508beb28a3ca267ba3b3f5 Mon Sep 17 00:00:00 2001 From: Pradeesh Date: Mon, 14 Feb 2022 08:46:15 +0000 Subject: [PATCH 06/35] made search more elegant --- pyserini/search/jass/_searcher.py | 68 ++++++++++++++++++++++--------- 1 file changed, 48 insertions(+), 20 deletions(-) diff --git a/pyserini/search/jass/_searcher.py b/pyserini/search/jass/_searcher.py index ee31f50f8..8dadb6fcc 100644 --- a/pyserini/search/jass/_searcher.py +++ b/pyserini/search/jass/_searcher.py @@ -43,24 +43,42 @@ def __init__(self, index_dir: str, version: int = 2): index = self.object.load_index(version,index_dir) if index != 0: raise Exception('Unable to load index - error code' + str(index)) + + + # XXX: TODO: This is the Lucene version for reference... - def search(self, q: str, k: int = 10, rho: int = 10, - fields=dict(), strip_segment_id=False, remove_dups=False) -> List[DenseSearchResult]: - - docid_score_pair = list() + def search(self, q: str, k: int = 10, rho: int = 10) -> List[DenseSearchResult]: + self.object.set_top_k(k) self.object.set_postings_to_process(rho) results = self.object.search(q) - queries = results.results_list.split('\n') - for query in queries: - qrel = query.split(' ') # split by space - if len(qrel) == 6: - docid_score_pair.append(DenseSearchResult(qrel[2], float(qrel[4]))) # make it as a dense object so pyserini downstream tasks know how to handle - quick way + return (self.convert_to_search_result(results.results_list)) + + + @abstractmethod + def batch_search(self, queries: List[str], qids: List[str], k: int = 10, threads: int = 1) -> Dict[str, List[DenseSearchResult]]: + """Perform batch search. + + Parameters + ---------- + queries : List[str] + List of queries. + qids : List[str] + List of query ids. + k : int + Number of results to return for each query. + threads : int + Number of threads to use. + + Returns + ------- + Dict[str, List[DenseSearchResult]] + Dict of query id to list of DenseSearchResult. + """ + raise NotImplementedError - - return docid_score_pair @@ -113,18 +131,28 @@ def search(self, q: str, k: int = 10, rho: int = 10, # results = self.object.batchSearchFields(query_strings, qid_strings, int(k), int(threads), jfields) # return {r.getKey(): r.getValue() for r in results.entrySet().toArray()} - # # XXX: TODO: This is the Anserini version but may be useful as reference - # def convert_to_search_result(run: TrecRun, docid_to_search_result: str]) -> List[JSimpleSearcherResult]: - # search_results = [] - # for _, _, docid, _, score, _ in run.to_numpy(): - # search_result = docid_to_search_result[docid] - # search_result.score = score - # search_results.append(search_result) +# Quick and dirty test to load index, search and also get the hits - # return search_results + def convert_to_search_result(self, result_list:str) -> List[DenseSearchResult]: + """Process a pyJass query and return the results in a list of DenseSearchResult. -# Quick and dirty test to load index, search and also get the hits + Parameters + ---------- + query : str + Query string fromy pyjass. Multiple queries are stored as with new line token. + + Returns + ------- + List[DenseSearchResult] + List of DenseSearchResult which contains DocID and also the score from pyJass query. + """ + docid_score_pair = list() + queries = result_list.split('\n') + for query in queries: + qrel = query.split(' ') # split by space + if len(qrel) == 6: + docid_score_pair.append(DenseSearchResult(qrel[2], float(qrel[4]))) # make it as a dense object so pyserini downstream tasks know how to handle - quick way def main(): blah = JASSv2Searcher('/home/pradeesh') # collection to Jass pre-built Index From 0b1aa4d7f1e36458a1aa0af50df0a0fe8786b0eb Mon Sep 17 00:00:00 2001 From: Pradeesh Date: Mon, 14 Feb 2022 10:09:09 +0000 Subject: [PATCH 07/35] implemented convert_to_search --- pyserini/search/jass/_searcher.py | 42 +++++++++++++++---------------- 1 file changed, 21 insertions(+), 21 deletions(-) diff --git a/pyserini/search/jass/_searcher.py b/pyserini/search/jass/_searcher.py index 8dadb6fcc..6aa1867b7 100644 --- a/pyserini/search/jass/_searcher.py +++ b/pyserini/search/jass/_searcher.py @@ -45,7 +45,27 @@ def __init__(self, index_dir: str, version: int = 2): raise Exception('Unable to load index - error code' + str(index)) - + def convert_to_search_result(self, result_list:str) -> List[DenseSearchResult]: + """Process a pyJass query and return the results in a list of DenseSearchResult. + + Parameters + ---------- + query : str + Query string fromy pyjass. Multiple queries are stored as with new line token. + + Returns + ------- + List[DenseSearchResult] + List of DenseSearchResult which contains DocID and also the score from pyJass query. + """ + docid_score_pair = list() + queries = result_list.split('\n') + for query in queries: + qrel = query.split(' ') # split by space + if len(qrel) == 6: + docid_score_pair.append(DenseSearchResult(qrel[2], float(qrel[4]))) # make it as a dense object so pyserini downstream tasks know how to handle - quick way + + # XXX: TODO: This is the Lucene version for reference... @@ -134,26 +154,6 @@ def batch_search(self, queries: List[str], qids: List[str], k: int = 10, threads # Quick and dirty test to load index, search and also get the hits - def convert_to_search_result(self, result_list:str) -> List[DenseSearchResult]: - """Process a pyJass query and return the results in a list of DenseSearchResult. - - Parameters - ---------- - query : str - Query string fromy pyjass. Multiple queries are stored as with new line token. - - Returns - ------- - List[DenseSearchResult] - List of DenseSearchResult which contains DocID and also the score from pyJass query. - """ - docid_score_pair = list() - queries = result_list.split('\n') - for query in queries: - qrel = query.split(' ') # split by space - if len(qrel) == 6: - docid_score_pair.append(DenseSearchResult(qrel[2], float(qrel[4]))) # make it as a dense object so pyserini downstream tasks know how to handle - quick way - def main(): blah = JASSv2Searcher('/home/pradeesh') # collection to Jass pre-built Index hits = blah.search('what is a lobster roll') From 1bde9262667b6e5ef89070591674066807ed7192 Mon Sep 17 00:00:00 2001 From: Pradeesh Date: Mon, 14 Feb 2022 22:15:55 +0000 Subject: [PATCH 08/35] fix return error --- pyserini/search/jass/_searcher.py | 40 ++++++++++++++++--------------- 1 file changed, 21 insertions(+), 19 deletions(-) diff --git a/pyserini/search/jass/_searcher.py b/pyserini/search/jass/_searcher.py index 6aa1867b7..ced6a04e6 100644 --- a/pyserini/search/jass/_searcher.py +++ b/pyserini/search/jass/_searcher.py @@ -65,6 +65,8 @@ def convert_to_search_result(self, result_list:str) -> List[DenseSearchResult]: if len(qrel) == 6: docid_score_pair.append(DenseSearchResult(qrel[2], float(qrel[4]))) # make it as a dense object so pyserini downstream tasks know how to handle - quick way + return docid_score_pair + @@ -77,27 +79,27 @@ def search(self, q: str, k: int = 10, rho: int = 10) -> List[DenseSearchResult]: return (self.convert_to_search_result(results.results_list)) - @abstractmethod - def batch_search(self, queries: List[str], qids: List[str], k: int = 10, threads: int = 1) -> Dict[str, List[DenseSearchResult]]: - """Perform batch search. + # @abstractmethod + # def batch_search(self, queries: List[str], qids: List[str], k: int = 10, threads: int = 1) -> Dict[str, List[DenseSearchResult]]: + # """Perform batch search. - Parameters - ---------- - queries : List[str] - List of queries. - qids : List[str] - List of query ids. - k : int - Number of results to return for each query. - threads : int - Number of threads to use. + # Parameters + # ---------- + # queries : List[str] + # List of queries. + # qids : List[str] + # List of query ids. + # k : int + # Number of results to return for each query. + # threads : int + # Number of threads to use. - Returns - ------- - Dict[str, List[DenseSearchResult]] - Dict of query id to list of DenseSearchResult. - """ - raise NotImplementedError + # Returns + # ------- + # Dict[str, List[DenseSearchResult]] + # Dict of query id to list of DenseSearchResult. + # """ + # raise NotImplementedError From cfed6c64db163f9b188ef4220408579a9aa2d41f Mon Sep 17 00:00:00 2001 From: Pradeesh Date: Mon, 14 Feb 2022 23:13:24 +0000 Subject: [PATCH 09/35] fix of handling query_id mapping --- pyserini/search/jass/_searcher.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyserini/search/jass/_searcher.py b/pyserini/search/jass/_searcher.py index ced6a04e6..c9cad4520 100644 --- a/pyserini/search/jass/_searcher.py +++ b/pyserini/search/jass/_searcher.py @@ -75,7 +75,7 @@ def search(self, q: str, k: int = 10, rho: int = 10) -> List[DenseSearchResult]: self.object.set_top_k(k) self.object.set_postings_to_process(rho) - results = self.object.search(q) + results = self.object.search("0:"+q) # appending "0: to handle jass' requirements" return (self.convert_to_search_result(results.results_list)) From 047c9d0648e0e5d717d612dec8ae02a35c427828 Mon Sep 17 00:00:00 2001 From: Pradeesh Date: Mon, 14 Feb 2022 23:19:42 +0000 Subject: [PATCH 10/35] if we see digit and a semicolon - consume it --- pyserini/search/jass/_searcher.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/pyserini/search/jass/_searcher.py b/pyserini/search/jass/_searcher.py index c9cad4520..e8674844f 100644 --- a/pyserini/search/jass/_searcher.py +++ b/pyserini/search/jass/_searcher.py @@ -75,7 +75,10 @@ def search(self, q: str, k: int = 10, rho: int = 10) -> List[DenseSearchResult]: self.object.set_top_k(k) self.object.set_postings_to_process(rho) - results = self.object.search("0:"+q) # appending "0: to handle jass' requirements" + if q[0].isdigit() and q[1] == ':': + results = self.object.search(q) + else: + results = self.object.search("0:"+q) # appending "0: to handle jass' requirements" return (self.convert_to_search_result(results.results_list)) @@ -158,7 +161,7 @@ def search(self, q: str, k: int = 10, rho: int = 10) -> List[DenseSearchResult]: def main(): blah = JASSv2Searcher('/home/pradeesh') # collection to Jass pre-built Index - hits = blah.search('what is a lobster roll') + hits = blah.search('2:what is a lobster roll') for i in range(0, 5): print(f'{i+1:2} {hits[i].docid:7} {hits[i].score:.5f}') From 14563666b57fb17626323630da0e0d69b675aa79 Mon Sep 17 00:00:00 2001 From: Pradeesh Date: Tue, 15 Feb 2022 07:17:27 +0000 Subject: [PATCH 11/35] initial start for batch_search --- pyserini/search/jass/_searcher.py | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/pyserini/search/jass/_searcher.py b/pyserini/search/jass/_searcher.py index e8674844f..86a96c16c 100644 --- a/pyserini/search/jass/_searcher.py +++ b/pyserini/search/jass/_searcher.py @@ -82,6 +82,27 @@ def search(self, q: str, k: int = 10, rho: int = 10) -> List[DenseSearchResult]: return (self.convert_to_search_result(results.results_list)) + # def list_to_cvector(s: List[str]) -> pyjass.string_vector: + + # return(pyjass.string_vector(s)) + + def zip_two_lists(l1: List[str], l2: List[str]) -> List[str]: + """Zip two lists together and return a list of tuples. + + Parameters + ---------- + l1 : List[str] + First list to zip. + l2 : List[str] + Second list to zip. + + Returns + ------- + List[str] + List of tuples of the two lists. + """ + return list(zip(l1, l2)) + # @abstractmethod # def batch_search(self, queries: List[str], qids: List[str], k: int = 10, threads: int = 1) -> Dict[str, List[DenseSearchResult]]: # """Perform batch search. @@ -166,6 +187,11 @@ def main(): for i in range(0, 5): print(f'{i+1:2} {hits[i].docid:7} {hits[i].score:.5f}') + list1 = ['blah','blah2'] + list2 = ['101','102'] + zipped = blah.zip_two_lists(list1, list2) + print(zipped) + if __name__ == "__main__": main() \ No newline at end of file From b411e818c24f08b9abff1101a29749a28d7890a3 Mon Sep 17 00:00:00 2001 From: Pradeesh Date: Tue, 15 Feb 2022 10:43:11 +0000 Subject: [PATCH 12/35] implemented batch_search for pyjass --- pyserini/search/jass/_searcher.py | 76 +++++++++++++++++++++++-------- 1 file changed, 57 insertions(+), 19 deletions(-) diff --git a/pyserini/search/jass/_searcher.py b/pyserini/search/jass/_searcher.py index 86a96c16c..f94ab1e11 100644 --- a/pyserini/search/jass/_searcher.py +++ b/pyserini/search/jass/_searcher.py @@ -82,27 +82,63 @@ def search(self, q: str, k: int = 10, rho: int = 10) -> List[DenseSearchResult]: return (self.convert_to_search_result(results.results_list)) - # def list_to_cvector(s: List[str]) -> pyjass.string_vector: + def __list_to_strvector(self,qids: List[str],queries: List[str]) -> pyjass.string_vector: + """Convert a list of queries to a c++ string_vector. + + Parameters + ---------- + qids : List[str] + List of query ids. + queries : List[str] + List of queries. + + Returns + ------- + pyjass.string_vector + c++ string_vector to be consumed by Jass. - # return(pyjass.string_vector(s)) + """ + return(pyjass.string_vector([str(x[0] + ":") + x[1] for x in zip(qids, queries)])) - def zip_two_lists(l1: List[str], l2: List[str]) -> List[str]: - """Zip two lists together and return a list of tuples. + + + def batch_search(self, queries: List[str], qids: List[str], k: int = 10, rho: int = 10, threads: int = 1) -> Dict[str, List[DenseSearchResult]]: + + """Perform batch search. Parameters ---------- - l1 : List[str] - First list to zip. - l2 : List[str] - Second list to zip. + queries : List[str] + List of queries. + qids : List[str] + List of query ids. + k : int + Number of results to return. + rho : int + Value of rho to use. + threads : int + Number of threads to use. Returns ------- - List[str] - List of tuples of the two lists. + Dict[str, List[DenseSearchResult]] + Dictionary of query id to list of DenseSearchResult. """ - return list(zip(l1, l2)) - + + self.object.set_top_k(k) + output = dict() + self.object.set_postings_to_process(rho) + results = self.object.threaded_search(self.__list_to_strvector(qids, queries), threads) + for i in range(len(results)): + if len(results[i].results) > 0: + for key in results[i].results.asdict().keys(): + output[key] = self.convert_to_search_result(results[i].results[key].results_list) + + return output + + + + # @abstractmethod # def batch_search(self, queries: List[str], qids: List[str], k: int = 10, threads: int = 1) -> Dict[str, List[DenseSearchResult]]: # """Perform batch search. @@ -182,15 +218,17 @@ def zip_two_lists(l1: List[str], l2: List[str]) -> List[str]: def main(): blah = JASSv2Searcher('/home/pradeesh') # collection to Jass pre-built Index - hits = blah.search('2:what is a lobster roll') + queries = ['new york pizza','what is a lobster roll','malaysia is awesome'] + qid = ['101','102','103'] + hits = blah.batch_search(queries,qid,10,2,3) + print(hits) + + + + # for i in range(0, 5): + # print(f'{i+1:2} {hits[i].docid:7} {hits[i].score:.5f}') - for i in range(0, 5): - print(f'{i+1:2} {hits[i].docid:7} {hits[i].score:.5f}') - list1 = ['blah','blah2'] - list2 = ['101','102'] - zipped = blah.zip_two_lists(list1, list2) - print(zipped) if __name__ == "__main__": From fecc7037e48800dca80aeda149a7ec53b245bc5f Mon Sep 17 00:00:00 2001 From: Pradeesh Date: Tue, 15 Feb 2022 10:45:01 +0000 Subject: [PATCH 13/35] cleaned up implementation --- pyserini/search/jass/_searcher.py | 77 ------------------------------- 1 file changed, 77 deletions(-) diff --git a/pyserini/search/jass/_searcher.py b/pyserini/search/jass/_searcher.py index f94ab1e11..2bf4b6003 100644 --- a/pyserini/search/jass/_searcher.py +++ b/pyserini/search/jass/_searcher.py @@ -69,8 +69,6 @@ def convert_to_search_result(self, result_list:str) -> List[DenseSearchResult]: - - # XXX: TODO: This is the Lucene version for reference... def search(self, q: str, k: int = 10, rho: int = 10) -> List[DenseSearchResult]: self.object.set_top_k(k) @@ -139,81 +137,6 @@ def batch_search(self, queries: List[str], qids: List[str], k: int = 10, rho: in - # @abstractmethod - # def batch_search(self, queries: List[str], qids: List[str], k: int = 10, threads: int = 1) -> Dict[str, List[DenseSearchResult]]: - # """Perform batch search. - - # Parameters - # ---------- - # queries : List[str] - # List of queries. - # qids : List[str] - # List of query ids. - # k : int - # Number of results to return for each query. - # threads : int - # Number of threads to use. - - # Returns - # ------- - # Dict[str, List[DenseSearchResult]] - # Dict of query id to list of DenseSearchResult. - # """ - # raise NotImplementedError - - - - - # def batch_search(self, queries: List[str], qids: List[str], k: int = 10, threads: int = 1, - # query_generator: JQueryGenerator = None, fields = dict()) -> Dict[str, List[pyjass.JASS_anytime_result]]: - # """Search the collection concurrently for multiple queries, using multiple threads. - - # Parameters - # ---------- - # queries : List[str] - # List of query strings. - # qids : List[str] - # List of corresponding query ids. - # k : int - # Number of hits to return. - # threads : int - # Maximum number of threads to use. - # query_generator : JQueryGenerator - # Generator to build queries. Set to ``None`` by default to use Anserini default. - # fields : dict - # Optional map of fields to search with associated boosts. - - # Returns - # ------- - # Dict[str, List[JSimpleSearcherResult]] - # Dictionary holding the search results, with the query ids as keys and the corresponding lists of search - # results as the values. - # """ - # query_strings = JArrayList() - # qid_strings = JArrayList() - # for query in queries: - # query_strings.add(query) - - # for qid in qids: - # qid_strings.add(qid) - - # jfields = JHashMap() - # for (field, boost) in fields.items(): - # jfields.put(field, JFloat(boost)) - - # if query_generator: - # if not fields: - # results = self.object.batchSearch(query_generator, query_strings, qid_strings, int(k), int(threads)) - # else: - # results = self.object.batchSearchFields(query_generator, query_strings, qid_strings, int(k), int(threads), jfields) - # else: - # if not fields: - # results = self.object.batchSearch(query_strings, qid_strings, int(k), int(threads)) - # else: - # results = self.object.batchSearchFields(query_strings, qid_strings, int(k), int(threads), jfields) - # return {r.getKey(): r.getValue() for r in results.entrySet().toArray()} - - # Quick and dirty test to load index, search and also get the hits def main(): From cab9da5bb5bf83e85a0ec426bf03372ab8042181 Mon Sep 17 00:00:00 2001 From: Pradeesh Date: Wed, 16 Feb 2022 01:17:35 +0000 Subject: [PATCH 14/35] updated to support pyjass 0.2a7 --- pyserini/search/jass/_searcher.py | 54 +++++++++++++++++++++++-------- 1 file changed, 40 insertions(+), 14 deletions(-) diff --git a/pyserini/search/jass/_searcher.py b/pyserini/search/jass/_searcher.py index 2bf4b6003..14c9cdca7 100644 --- a/pyserini/search/jass/_searcher.py +++ b/pyserini/search/jass/_searcher.py @@ -19,15 +19,23 @@ class, which wraps the C++ ``JASS_anytime_api``. """ +from dataclasses import dataclass import logging import pyjass from typing import Dict, List, Optional, Union from pyserini.trectools import TrecRun -from pyserini.dsearch import DenseSearchResult logger = logging.getLogger(__name__) # Wrappers around JASS classes +@dataclass +class JASSv2SearcherResult: + docid: str # doc id + score: float # score in flaot + # query: str #query + # postings_processed: int # no of posting processed + + class JASSv2Searcher: """Wrapper class for the ``JASS_anytime_api`` in JASSv2. @@ -40,12 +48,12 @@ class JASSv2Searcher: def __init__(self, index_dir: str, version: int = 2): self.index_dir = index_dir self.object = pyjass.anytime() - index = self.object.load_index(version,index_dir) + index = self.object.load_index(version,'/home/pradeesh') if index != 0: raise Exception('Unable to load index - error code' + str(index)) - def convert_to_search_result(self, result_list:str) -> List[DenseSearchResult]: + def convert_to_search_result(self, result_list:str) -> List[JASSv2SearcherResult]: """Process a pyJass query and return the results in a list of DenseSearchResult. Parameters @@ -63,13 +71,30 @@ def convert_to_search_result(self, result_list:str) -> List[DenseSearchResult]: for query in queries: qrel = query.split(' ') # split by space if len(qrel) == 6: - docid_score_pair.append(DenseSearchResult(qrel[2], float(qrel[4]))) # make it as a dense object so pyserini downstream tasks know how to handle - quick way + docid_score_pair.append(JASSv2SearcherResult(qrel[2], float(qrel[4]))) # make it as a dense object so pyserini downstream tasks know how to handle - quick way return docid_score_pair - def search(self, q: str, k: int = 10, rho: int = 10) -> List[DenseSearchResult]: + def search(self, q: str, k: int = 10, rho: int = 10) -> List[JASSv2SearcherResult]: + """Search the collection for a single query. + + Parameters + ---------- + q : str + Query string. + k : int + Number of results to return. + rho : int + Value of rho to use. + + Returns + ------- + List[JASSv2SearcherResult] + List of search results. + + """ self.object.set_top_k(k) self.object.set_postings_to_process(rho) @@ -80,7 +105,7 @@ def search(self, q: str, k: int = 10, rho: int = 10) -> List[DenseSearchResult]: return (self.convert_to_search_result(results.results_list)) - def __list_to_strvector(self,qids: List[str],queries: List[str]) -> pyjass.string_vector: + def __list_to_strvector(self,qids: List[str],queries: List[str]) -> pyjass.JASS_string_vector: """Convert a list of queries to a c++ string_vector. Parameters @@ -96,13 +121,13 @@ def __list_to_strvector(self,qids: List[str],queries: List[str]) -> pyjass.strin c++ string_vector to be consumed by Jass. """ - return(pyjass.string_vector([str(x[0] + ":") + x[1] for x in zip(qids, queries)])) + return(pyjass.JASS_string_vector([str(x[0] + ":") + x[1] for x in zip(qids, queries)])) - def batch_search(self, queries: List[str], qids: List[str], k: int = 10, rho: int = 10, threads: int = 1) -> Dict[str, List[DenseSearchResult]]: + def batch_search(self, queries: List[str], qids: List[str], k: int = 10, rho: int = 10, threads: int = 1) -> Dict[str, List[JASSv2SearcherResult]]: - """Perform batch search. + """Search the collection concurrently for multiple queries, using multiple threads. Parameters ---------- @@ -119,8 +144,9 @@ def batch_search(self, queries: List[str], qids: List[str], k: int = 10, rho: in Returns ------- - Dict[str, List[DenseSearchResult]] - Dictionary of query id to list of DenseSearchResult. + Dict[str, List[JASSv2SearcherResult]] + Dictionary holding the search results, with the query ids as keys and the corresponding lists of search + results as the values. """ self.object.set_top_k(k) @@ -141,9 +167,9 @@ def batch_search(self, queries: List[str], qids: List[str], k: int = 10, rho: in def main(): blah = JASSv2Searcher('/home/pradeesh') # collection to Jass pre-built Index - queries = ['new york pizza','what is a lobster roll','malaysia is awesome'] - qid = ['101','102','103'] - hits = blah.batch_search(queries,qid,10,2,3) + queries = ['new york pizza','what is a lobster roll','malaysia is awesome'] # queries to search + qid = ['101','102','103'] #queries id + hits = blah.batch_search(queries,qid,10,2,3) # print(hits) From cd887c64e33cde1ca2c3fc83093acbf2d4000b37 Mon Sep 17 00:00:00 2001 From: Pradeesh Date: Wed, 16 Feb 2022 01:17:55 +0000 Subject: [PATCH 15/35] added support for jass_search --- pyserini/search/__init__.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pyserini/search/__init__.py b/pyserini/search/__init__.py index fea940d67..b2672c3da 100644 --- a/pyserini/search/__init__.py +++ b/pyserini/search/__init__.py @@ -20,9 +20,10 @@ from .lucene import SimpleNearestNeighborSearcher, JSimpleNearestNeighborSearcherResult from .lucene import JImpactSearcherResult, LuceneImpactSearcher from ._deprecated import SimpleSearcher, ImpactSearcher, SimpleFusionSearcher +from .jass import JASSv2Searcher __all__ = ['JQuery', 'LuceneSimilarities', 'LuceneFusionSearcher', 'LuceneSearcher', 'JSimpleSearcherResult', 'SimpleNearestNeighborSearcher', 'JSimpleNearestNeighborSearcherResult', 'LuceneImpactSearcher', 'JImpactSearcherResult', 'JDisjunctionMaxQueryGenerator', 'get_topics', 'get_topics_with_reader', 'get_qrels_file', 'get_qrels', - 'SimpleSearcher', 'ImpactSearcher', 'SimpleFusionSearcher'] + 'SimpleSearcher', 'ImpactSearcher', 'SimpleFusionSearcher','JASSv2Searcher'] From 249ebe64bb1f981e0b3876556205605ebbb57470 Mon Sep 17 00:00:00 2001 From: Pradeesh Date: Wed, 16 Feb 2022 02:14:37 +0000 Subject: [PATCH 16/35] getting main driver class to work --- pyserini/search/jass/__init__.py | 20 ++++ pyserini/search/jass/__main__.py | 190 ++++-------------------------- pyserini/search/jass/_searcher.py | 3 +- pyserini/search/jass/test.py | 10 ++ 4 files changed, 53 insertions(+), 170 deletions(-) create mode 100644 pyserini/search/jass/__init__.py create mode 100644 pyserini/search/jass/test.py diff --git a/pyserini/search/jass/__init__.py b/pyserini/search/jass/__init__.py new file mode 100644 index 000000000..6f6ce7ff1 --- /dev/null +++ b/pyserini/search/jass/__init__.py @@ -0,0 +1,20 @@ +# +# Pyserini: Reproducible IR research with sparse and dense representations +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from ._searcher import JASSv2Searcher , JASSv2SearcherResult + + +__all__ = ['JASSv2Searcher', 'JASSv2SearcherResult'] \ No newline at end of file diff --git a/pyserini/search/jass/__main__.py b/pyserini/search/jass/__main__.py index 6ddb065db..39389ebfe 100644 --- a/pyserini/search/jass/__main__.py +++ b/pyserini/search/jass/__main__.py @@ -16,73 +16,26 @@ import argparse import os +import errno from tqdm import tqdm -from transformers import AutoTokenizer -from pyserini.analysis import JDefaultEnglishAnalyzer, JWhiteSpaceAnalyzer from pyserini.output_writer import OutputFormat, get_output_writer -from pyserini.pyclass import autoclass from pyserini.query_iterator import get_query_iterator, TopicsFormat -from pyserini.search import ImpactSearcher, SimpleSearcher, JDisjunctionMaxQueryGenerator -from pyserini.search.lucene.reranker import ClassifierType, PseudoRelevanceClassifierReranker +from pyserini.search.jass import JASSv2Searcher -def set_bm25_parameters(searcher, index, k1=None, b=None): - if k1 is not None or b is not None: - if k1 is None or b is None: - print('Must set *both* k1 and b for BM25!') - exit() - print(f'Setting BM25 parameters: k1={k1}, b={b}') - searcher.set_bm25(k1, b) - else: - # Automatically set bm25 parameters based on known index: - if index == 'msmarco-passage' or index == 'msmarco-passage-slim': - print('MS MARCO passage: setting k1=0.82, b=0.68') - searcher.set_bm25(0.82, 0.68) - elif index == 'msmarco-passage-expanded': - print('MS MARCO passage w/ doc2query-T5 expansion: setting k1=2.18, b=0.86') - searcher.set_bm25(2.18, 0.86) - elif index == 'msmarco-doc' or index == 'msmarco-doc-slim': - print('MS MARCO doc: setting k1=4.46, b=0.82') - searcher.set_bm25(4.46, 0.82) - elif index == 'msmarco-doc-per-passage' or index == 'msmarco-doc-per-passage-slim': - print('MS MARCO doc, per passage: setting k1=2.16, b=0.61') - searcher.set_bm25(2.16, 0.61) - elif index == 'msmarco-doc-expanded-per-doc': - print('MS MARCO doc w/ doc2query-T5 (per doc) expansion: setting k1=4.68, b=0.87') - searcher.set_bm25(4.68, 0.87) - elif index == 'msmarco-doc-expanded-per-passage': - print('MS MARCO doc w/ doc2query-T5 (per passage) expansion: setting k1=2.56, b=0.59') - searcher.set_bm25(2.56, 0.59) def define_search_args(parser): - parser.add_argument('--index', type=str, metavar='path to index or index name', required=True, + parser.add_argument('--index', type=str, default='/home/pradeesh', metavar='path to index or index name', required=False, help="Path to pyJass index") - - parser.add_argument('--bm25', action='store_true', default=True, help="Use BM25 (default).") - parser.add_argument('--rho', type=float, help='rho parameter.') - - parser.add_argument('--prcl', type=ClassifierType, nargs='+', default=[], - help='Specify the classifier PseudoRelevanceClassifierReranker uses.') - parser.add_argument('--prcl.vectorizer', dest='vectorizer', type=str, - help='Type of vectorizer. Available: TfidfVectorizer, BM25Vectorizer.') - parser.add_argument('--prcl.r', dest='r', type=int, default=10, - help='Number of positive labels in pseudo relevance feedback.') - parser.add_argument('--prcl.n', dest='n', type=int, default=100, - help='Number of negative labels in pseudo relevance feedback.') - parser.add_argument('--prcl.alpha', dest='alpha', type=float, default=0.5, - help='Alpha value for interpolation in pseudo relevance feedback.') - - - + parser.add_argument('--rho', type=int, help='rho parameter.') if __name__ == "__main__": - JSimpleSearcher = autoclass('io.anserini.search.SimpleSearcher') parser = argparse.ArgumentParser(description='Search a pyJass index.') define_search_args(parser) - parser.add_argument('--topics', type=str, metavar='topic_name', required=True, + parser.add_argument('--topics', type=str, default='/home/pradeesh/query/sample_queries.tsv',metavar='topic_name', required=False, help="Name of topics. Available: robust04, robust05, core17, core18.") parser.add_argument('--hits', type=int, metavar='num', required=False, default=1000, help="Number of hits.") @@ -92,141 +45,59 @@ def define_search_args(parser): help=f"Format of output. Available: {[x.value for x in list(OutputFormat)]}") parser.add_argument('--output', type=str, metavar='path', help="Path to output file.") - parser.add_argument('--max-passage', action='store_true', - default=False, help="Select only max passage from document.") - parser.add_argument('--max-passage-hits', type=int, metavar='num', required=False, default=100, - help="Final number of hits when selecting only max passage.") - parser.add_argument('--max-passage-delimiter', type=str, metavar='str', required=False, default='#', - help="Delimiter between docid and passage id.") parser.add_argument('--batch-size', type=int, metavar='num', required=False, default=1, help="Specify batch size to search the collection concurrently.") parser.add_argument('--threads', type=int, metavar='num', required=False, default=1, help="Maximum number of threads to use.") - args = parser.parse_args() query_iterator = get_query_iterator(args.topics, TopicsFormat(args.topics_format)) topics = query_iterator.topics - if not args.impact: - if os.path.exists(args.index): - # create searcher from index directory - searcher = SimpleSearcher(args.index) - else: - # create searcher from prebuilt index name - searcher = SimpleSearcher.from_prebuilt_index(args.index) - elif args.impact: - if os.path.exists(args.index): - searcher = ImpactSearcher(args.index, args.encoder, args.min_idf) - else: - searcher = ImpactSearcher.from_prebuilt_index(args.index, args.encoder, args.min_idf) - - if args.language != 'en': - searcher.set_language(args.language) + if os.path.exists(args.index): + # create searcher from index directory + print(args.index) + searcher = JASSv2Searcher('/home/pradeesh',2) + else: + # TODO: handle pre_build index if it's not found but we will throw file not found + raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), args.index) if not searcher: exit() - search_rankers = [] - - if args.qld: - search_rankers.append('qld') - searcher.set_qld() - elif args.bm25: - search_rankers.append('bm25') - set_bm25_parameters(searcher, args.index, args.k1, args.b) - - if args.rm3: - search_rankers.append('rm3') - searcher.set_rm3() - fields = dict() if args.fields: fields = dict([pair.split('=') for pair in args.fields]) print(f'Searching over fields: {fields}') - query_generator = None - if args.dismax: - query_generator = JDisjunctionMaxQueryGenerator(args.tiebreaker) - print(f'Using dismax query generator with tiebreaker={args.tiebreaker}') - - if args.tokenizer != None: - analyzer = JWhiteSpaceAnalyzer() - searcher.set_analyzer(analyzer) - print(f'Using whitespace analyzer because of pretokenized topics') - tokenizer = AutoTokenizer.from_pretrained(args.tokenizer) - print(f'Using {args.tokenizer} to preprocess topics') - - if args.stopwords: - analyzer = JDefaultEnglishAnalyzer.fromArguments('porter', False, args.stopwords) - searcher.set_analyzer(analyzer) - print(f'Using custom stopwords={args.stopwords}') - - # get re-ranker - use_prcl = args.prcl and len(args.prcl) > 0 and args.alpha > 0 - if use_prcl is True: - ranker = PseudoRelevanceClassifierReranker( - searcher.index_dir, args.vectorizer, args.prcl, r=args.r, n=args.n, alpha=args.alpha) # build output path output_path = args.output if output_path is None: - if use_prcl is True: - clf_rankers = [] - for t in args.prcl: - if t == ClassifierType.LR: - clf_rankers.append('lr') - elif t == ClassifierType.SVM: - clf_rankers.append('svm') - - r_str = f'prcl.r_{args.r}' - n_str = f'prcl.n_{args.n}' - a_str = f'prcl.alpha_{args.alpha}' - clf_str = 'prcl_' + '+'.join(clf_rankers) - tokens1 = ['run', args.topics, '+'.join(search_rankers)] - tokens2 = [args.vectorizer, clf_str, r_str, n_str, a_str] - output_path = '.'.join(tokens1) + '-' + '-'.join(tokens2) + ".txt" - else: - tokens = ['run', args.topics, '+'.join(search_rankers), 'txt'] - output_path = '.'.join(tokens) + tokens = ['run', args.topics, '+'.join(['rho',args.rho]), 'txt'] # we use the rho output + output_path = '.'.join(tokens) print(f'Running {args.topics} topics, saving to {output_path}...') - tag = output_path[:-4] if args.output is None else 'Anserini' + tag = output_path[:-4] if args.output is None else 'pyJass' output_writer = get_output_writer(output_path, OutputFormat(args.output_format), 'w', - max_hits=args.hits, tag=tag, topics=topics, - use_max_passage=args.max_passage, - max_passage_delimiter=args.max_passage_delimiter, - max_passage_hits=args.max_passage_hits) + max_hits=args.hits, tag=tag, topics=topics) with output_writer: batch_topics = list() batch_topic_ids = list() for index, (topic_id, text) in enumerate(tqdm(query_iterator, total=len(topics.keys()))): - if (args.tokenizer != None): - toks = tokenizer.tokenize(text) - text = ' ' - text = text.join(toks) if args.batch_size <= 1 and args.threads <= 1: - if args.impact: - hits = searcher.search(text, args.hits, fields=fields) - else: - hits = searcher.search(text, args.hits, query_generator=query_generator, fields=fields) + hits = searcher.search(text, args.hits, fields=fields) results = [(topic_id, hits)] else: batch_topic_ids.append(str(topic_id)) batch_topics.append(text) if (index + 1) % args.batch_size == 0 or \ - index == len(topics.keys()) - 1: - if args.impact: - results = searcher.batch_search( - batch_topics, batch_topic_ids, args.hits, args.threads, fields=fields - ) - else: - results = searcher.batch_search( - batch_topics, batch_topic_ids, args.hits, args.threads, - query_generator=query_generator, fields=fields - ) + index == len(topics.keys()) - 1: + results = searcher.batch_search( + batch_topics, batch_topic_ids, args.hits, args.threads, fields=fields + ) results = [(id_, results[id_]) for id_ in batch_topic_ids] batch_topic_ids.clear() batch_topics.clear() @@ -234,25 +105,6 @@ def define_search_args(parser): continue for topic, hits in results: - # do rerank - if use_prcl and len(hits) > (args.r + args.n): - docids = [hit.docid.strip() for hit in hits] - scores = [hit.score for hit in hits] - scores, docids = ranker.rerank(docids, scores) - docid_score_map = dict(zip(docids, scores)) - for hit in hits: - hit.score = docid_score_map[hit.docid.strip()] - - if args.remove_duplicates: - seen_docids = set() - dedup_hits = [] - for hit in hits: - if hit.docid.strip() in seen_docids: - continue - seen_docids.add(hit.docid.strip()) - dedup_hits.append(hit) - hits = dedup_hits - # write results output_writer.write(topic, hits) diff --git a/pyserini/search/jass/_searcher.py b/pyserini/search/jass/_searcher.py index 14c9cdca7..23cb0a282 100644 --- a/pyserini/search/jass/_searcher.py +++ b/pyserini/search/jass/_searcher.py @@ -48,7 +48,8 @@ class JASSv2Searcher: def __init__(self, index_dir: str, version: int = 2): self.index_dir = index_dir self.object = pyjass.anytime() - index = self.object.load_index(version,'/home/pradeesh') + print(self.object) + index = self.object.load_index(version,index_dir) if index != 0: raise Exception('Unable to load index - error code' + str(index)) diff --git a/pyserini/search/jass/test.py b/pyserini/search/jass/test.py new file mode 100644 index 000000000..a0760fd23 --- /dev/null +++ b/pyserini/search/jass/test.py @@ -0,0 +1,10 @@ +from pyserini.search import JASSv2Searcher +import jass + + + +searcher = JASSv2Searcher('msmarco-passage') +hits = searcher.search('what is a lobster roll?') + +for i in range(0, 10): + print(f'{i+1:2} {hits[i].docid:7} {hits[i].score:.5f}') \ No newline at end of file From 1e1ec1d67585d51b869fdfc76b6b278da5971cfa Mon Sep 17 00:00:00 2001 From: Pradeesh Date: Wed, 16 Feb 2022 03:39:07 +0000 Subject: [PATCH 17/35] some final fixes --- pyserini/search/jass/__main__.py | 28 +++++++++++++--------------- pyserini/search/jass/_searcher.py | 27 ++------------------------- 2 files changed, 15 insertions(+), 40 deletions(-) diff --git a/pyserini/search/jass/__main__.py b/pyserini/search/jass/__main__.py index 39389ebfe..7357a225e 100644 --- a/pyserini/search/jass/__main__.py +++ b/pyserini/search/jass/__main__.py @@ -14,28 +14,28 @@ # limitations under the License. # +import pyjass import argparse import os import errno - from tqdm import tqdm from pyserini.output_writer import OutputFormat, get_output_writer from pyserini.query_iterator import get_query_iterator, TopicsFormat -from pyserini.search.jass import JASSv2Searcher +from pyserini.search import JASSv2Searcher def define_search_args(parser): - parser.add_argument('--index', type=str, default='/home/pradeesh', metavar='path to index or index name', required=False, + parser.add_argument('--index', type=str, metavar='path to index or index name', required=True, help="Path to pyJass index") - parser.add_argument('--rho', type=int, help='rho parameter.') + parser.add_argument('--rho', type=int, default=10, help='rho parameter.') if __name__ == "__main__": parser = argparse.ArgumentParser(description='Search a pyJass index.') define_search_args(parser) - parser.add_argument('--topics', type=str, default='/home/pradeesh/query/sample_queries.tsv',metavar='topic_name', required=False, + parser.add_argument('--topics', type=str, metavar='topic_name', required=True, help="Name of topics. Available: robust04, robust05, core17, core18.") parser.add_argument('--hits', type=int, metavar='num', required=False, default=1000, help="Number of hits.") @@ -43,7 +43,7 @@ def define_search_args(parser): help=f"Format of topics. Available: {[x.value for x in list(TopicsFormat)]}") parser.add_argument('--output-format', type=str, metavar='format', default=OutputFormat.TREC.value, help=f"Format of output. Available: {[x.value for x in list(OutputFormat)]}") - parser.add_argument('--output', type=str, metavar='path', + parser.add_argument('--output', type=str, default='/home/prasys/output.txt', metavar='path', help="Path to output file.") parser.add_argument('--batch-size', type=int, metavar='num', required=False, default=1, help="Specify batch size to search the collection concurrently.") @@ -56,8 +56,7 @@ def define_search_args(parser): if os.path.exists(args.index): # create searcher from index directory - print(args.index) - searcher = JASSv2Searcher('/home/pradeesh',2) + searcher = JASSv2Searcher(args.index,2) else: # TODO: handle pre_build index if it's not found but we will throw file not found raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), args.index) @@ -66,15 +65,15 @@ def define_search_args(parser): exit() fields = dict() - if args.fields: - fields = dict([pair.split('=') for pair in args.fields]) - print(f'Searching over fields: {fields}') + # if args.fields: + # fields = dict([pair.split('=') for pair in args.fields]) + # print(f'Searching over fields: {fields}') # build output path output_path = args.output if output_path is None: - tokens = ['run', args.topics, '+'.join(['rho',args.rho]), 'txt'] # we use the rho output + tokens = ['run', args.topics, '+'.join(['rho',str(args.rho)]), 'txt'] # we use the rho output output_path = '.'.join(tokens) print(f'Running {args.topics} topics, saving to {output_path}...') @@ -88,7 +87,7 @@ def define_search_args(parser): batch_topic_ids = list() for index, (topic_id, text) in enumerate(tqdm(query_iterator, total=len(topics.keys()))): if args.batch_size <= 1 and args.threads <= 1: - hits = searcher.search(text, args.hits, fields=fields) + hits = searcher.search(text, args.hits, args.rho) results = [(topic_id, hits)] else: batch_topic_ids.append(str(topic_id)) @@ -96,8 +95,7 @@ def define_search_args(parser): if (index + 1) % args.batch_size == 0 or \ index == len(topics.keys()) - 1: results = searcher.batch_search( - batch_topics, batch_topic_ids, args.hits, args.threads, fields=fields - ) + batch_topics, batch_topic_ids, args.hits, args.rho, args.threads) results = [(id_, results[id_]) for id_ in batch_topic_ids] batch_topic_ids.clear() batch_topics.clear() diff --git a/pyserini/search/jass/_searcher.py b/pyserini/search/jass/_searcher.py index 23cb0a282..1ac3e6de5 100644 --- a/pyserini/search/jass/_searcher.py +++ b/pyserini/search/jass/_searcher.py @@ -32,6 +32,7 @@ class JASSv2SearcherResult: docid: str # doc id score: float # score in flaot + #TODO Implement the follow attributes specially for JASSv2 # query: str #query # postings_processed: int # no of posting processed @@ -48,10 +49,9 @@ class JASSv2Searcher: def __init__(self, index_dir: str, version: int = 2): self.index_dir = index_dir self.object = pyjass.anytime() - print(self.object) index = self.object.load_index(version,index_dir) if index != 0: - raise Exception('Unable to load index - error code' + str(index)) + raise Exception('Unable to load index - error code' + str(index)) def convert_to_search_result(self, result_list:str) -> List[JASSv2SearcherResult]: @@ -160,26 +160,3 @@ def batch_search(self, queries: List[str], qids: List[str], k: int = 10, rho: in output[key] = self.convert_to_search_result(results[i].results[key].results_list) return output - - - - -# Quick and dirty test to load index, search and also get the hits - -def main(): - blah = JASSv2Searcher('/home/pradeesh') # collection to Jass pre-built Index - queries = ['new york pizza','what is a lobster roll','malaysia is awesome'] # queries to search - qid = ['101','102','103'] #queries id - hits = blah.batch_search(queries,qid,10,2,3) # - print(hits) - - - - # for i in range(0, 5): - # print(f'{i+1:2} {hits[i].docid:7} {hits[i].score:.5f}') - - - - -if __name__ == "__main__": - main() \ No newline at end of file From c882ebd78c0caa7568fd295d577c414f23a4c437 Mon Sep 17 00:00:00 2001 From: Joel Mackenzie Date: Wed, 16 Feb 2022 16:55:34 +1100 Subject: [PATCH 18/35] Minor clean up --- pyserini/search/jass/__main__.py | 17 +++++-------- pyserini/search/jass/_searcher.py | 42 ++++++++++++++++++++----------- 2 files changed, 33 insertions(+), 26 deletions(-) diff --git a/pyserini/search/jass/__main__.py b/pyserini/search/jass/__main__.py index 7357a225e..6d159a6fb 100644 --- a/pyserini/search/jass/__main__.py +++ b/pyserini/search/jass/__main__.py @@ -25,12 +25,10 @@ from pyserini.search import JASSv2Searcher - - def define_search_args(parser): parser.add_argument('--index', type=str, metavar='path to index or index name', required=True, help="Path to pyJass index") - parser.add_argument('--rho', type=int, default=10, help='rho parameter.') + parser.add_argument('--rho', type=int, default=1000000000, help='rho: how many postings to process') if __name__ == "__main__": parser = argparse.ArgumentParser(description='Search a pyJass index.') @@ -43,7 +41,7 @@ def define_search_args(parser): help=f"Format of topics. Available: {[x.value for x in list(TopicsFormat)]}") parser.add_argument('--output-format', type=str, metavar='format', default=OutputFormat.TREC.value, help=f"Format of output. Available: {[x.value for x in list(OutputFormat)]}") - parser.add_argument('--output', type=str, default='/home/prasys/output.txt', metavar='path', + parser.add_argument('--output', type=str, metavar='path', help="Path to output file.") parser.add_argument('--batch-size', type=int, metavar='num', required=False, default=1, help="Specify batch size to search the collection concurrently.") @@ -56,7 +54,7 @@ def define_search_args(parser): if os.path.exists(args.index): # create searcher from index directory - searcher = JASSv2Searcher(args.index,2) + searcher = JASSv2Searcher(args.index, 2) else: # TODO: handle pre_build index if it's not found but we will throw file not found raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), args.index) @@ -64,16 +62,13 @@ def define_search_args(parser): if not searcher: exit() - fields = dict() - # if args.fields: - # fields = dict([pair.split('=') for pair in args.fields]) - # print(f'Searching over fields: {fields}') - + # JASS does not (yet) support field-based retrieval + fields = None # build output path output_path = args.output if output_path is None: - tokens = ['run', args.topics, '+'.join(['rho',str(args.rho)]), 'txt'] # we use the rho output + tokens = ['run', args.topics, '_'.join(['rho',str(args.rho)]), 'txt'] # we use the rho output output_path = '.'.join(tokens) print(f'Running {args.topics} topics, saving to {output_path}...') diff --git a/pyserini/search/jass/_searcher.py b/pyserini/search/jass/_searcher.py index 1ac3e6de5..cb9fea74e 100644 --- a/pyserini/search/jass/_searcher.py +++ b/pyserini/search/jass/_searcher.py @@ -31,13 +31,20 @@ @dataclass class JASSv2SearcherResult: docid: str # doc id - score: float # score in flaot - #TODO Implement the follow attributes specially for JASSv2 + score: float # score in float + #TODO Implement the following attributes specially for JASSv2 # query: str #query # postings_processed: int # no of posting processed class JASSv2Searcher: + + # Constants + EXPECTED_ENTRIES = 6 + DOCID_POS = 2 + SCORE_POS = 4 + ONE_BILLION = 1000000000 + """Wrapper class for the ``JASS_anytime_api`` in JASSv2. Parameters @@ -64,21 +71,22 @@ def convert_to_search_result(self, result_list:str) -> List[JASSv2SearcherResult Returns ------- - List[DenseSearchResult] - List of DenseSearchResult which contains DocID and also the score from pyJass query. + List[JASSv2SearcherResult] + List of JASSv2SearcherResult which contains the DocID and also the score pair. """ docid_score_pair = list() - queries = result_list.split('\n') - for query in queries: - qrel = query.split(' ') # split by space - if len(qrel) == 6: - docid_score_pair.append(JASSv2SearcherResult(qrel[2], float(qrel[4]))) # make it as a dense object so pyserini downstream tasks know how to handle - quick way - + results = result_list.split('\n') + for res in results: + # Split by space. We expect the `trec` format, bail out if we don't get it + result_data = res.split(' ') + if len(result_data) == self.EXPECTED_ENTRIES: + # All is well, append the [docid, score] tuple. + docid_score_pair.append(JASSv2SearcherResult(result_data[self.DOCID_POS], float(result_data[self.SCORE_POS]))) return docid_score_pair - def search(self, q: str, k: int = 10, rho: int = 10) -> List[JASSv2SearcherResult]: + def search(self, q: str, k: int = 10, rho: int = ONE_BILLION) -> List[JASSv2SearcherResult]: """Search the collection for a single query. Parameters @@ -99,14 +107,18 @@ def search(self, q: str, k: int = 10, rho: int = 10) -> List[JASSv2SearcherResul self.object.set_top_k(k) self.object.set_postings_to_process(rho) - if q[0].isdigit() and q[1] == ':': + # JASS expects queries to be an identifier followed by terms, delimited by either ':', '\t', or ' ' + # We do not want to split on spaces as it may result in discarded terms. + split_query = q.split(":\t") + # Assume the first field is the identifier... + if len(split_query) == 2: results = self.object.search(q) else: - results = self.object.search("0:"+q) # appending "0: to handle jass' requirements" + results = self.object.search("0:"+q) # appending `0:` so JASS consumes it as the identifier return (self.convert_to_search_result(results.results_list)) - def __list_to_strvector(self,qids: List[str],queries: List[str]) -> pyjass.JASS_string_vector: + def __list_to_strvector(self, qids: List[str] ,queries: List[str]) -> pyjass.JASS_string_vector: """Convert a list of queries to a c++ string_vector. Parameters @@ -126,7 +138,7 @@ def __list_to_strvector(self,qids: List[str],queries: List[str]) -> pyjass.JASS_ - def batch_search(self, queries: List[str], qids: List[str], k: int = 10, rho: int = 10, threads: int = 1) -> Dict[str, List[JASSv2SearcherResult]]: + def batch_search(self, queries: List[str], qids: List[str], k: int = 10, rho: int = ONE_BILLION, threads: int = 1) -> Dict[str, List[JASSv2SearcherResult]]: """Search the collection concurrently for multiple queries, using multiple threads. From ca2d306ca244d2e9b4cb99992f96017d44cf6e43 Mon Sep 17 00:00:00 2001 From: Pradeesh Date: Wed, 16 Feb 2022 22:06:15 +0000 Subject: [PATCH 19/35] implemented from_prebuilt_index from lucene --- pyserini/search/jass/_searcher.py | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/pyserini/search/jass/_searcher.py b/pyserini/search/jass/_searcher.py index cb9fea74e..27d4bf87c 100644 --- a/pyserini/search/jass/_searcher.py +++ b/pyserini/search/jass/_searcher.py @@ -24,6 +24,7 @@ import pyjass from typing import Dict, List, Optional, Union from pyserini.trectools import TrecRun +from pyserini.util import download_prebuilt_index logger = logging.getLogger(__name__) # Wrappers around JASS classes @@ -59,6 +60,31 @@ def __init__(self, index_dir: str, version: int = 2): index = self.object.load_index(version,index_dir) if index != 0: raise Exception('Unable to load index - error code' + str(index)) + + + @classmethod + def from_prebuilt_index(cls, prebuilt_index_name: str): + """Build a searcher from a pre-built index; download the index if necessary. + + Parameters + ---------- + prebuilt_index_name : str + Prebuilt index name. + + Returns + ------- + SimpleSearcher + Searcher built from the prebuilt index. + """ + print(f'Attempting to initialize pre-built index {prebuilt_index_name}.') + try: + index_dir = download_prebuilt_index(prebuilt_index_name) + except ValueError as e: + print(str(e)) + return None + + print(f'Initializing {prebuilt_index_name}...') + return cls(index_dir) def convert_to_search_result(self, result_list:str) -> List[JASSv2SearcherResult]: From 52f077a873ad7d6f39ce1b0350eab240622f1ff8 Mon Sep 17 00:00:00 2001 From: Pradeesh Date: Wed, 16 Feb 2022 22:50:58 +0000 Subject: [PATCH 20/35] exposed ascii/query parser to main_driver --- pyserini/search/jass/__main__.py | 21 ++++++++++++++------- pyserini/search/jass/_searcher.py | 9 ++++++++- 2 files changed, 22 insertions(+), 8 deletions(-) diff --git a/pyserini/search/jass/__main__.py b/pyserini/search/jass/__main__.py index 6d159a6fb..e35a48dca 100644 --- a/pyserini/search/jass/__main__.py +++ b/pyserini/search/jass/__main__.py @@ -14,21 +14,22 @@ # limitations under the License. # -import pyjass +from pyserini.search import JASSv2Searcher import argparse import os -import errno from tqdm import tqdm from pyserini.output_writer import OutputFormat, get_output_writer from pyserini.query_iterator import get_query_iterator, TopicsFormat -from pyserini.search import JASSv2Searcher + def define_search_args(parser): parser.add_argument('--index', type=str, metavar='path to index or index name', required=True, help="Path to pyJass index") parser.add_argument('--rho', type=int, default=1000000000, help='rho: how many postings to process') + parser.add_argument('--ascii', default=True, action='store_true', help="Use ASCII parser") + parser.add_argument('--query', action='store_true', help="Use Query Parser") if __name__ == "__main__": parser = argparse.ArgumentParser(description='Search a pyJass index.') @@ -53,11 +54,9 @@ def define_search_args(parser): topics = query_iterator.topics if os.path.exists(args.index): - # create searcher from index directory searcher = JASSv2Searcher(args.index, 2) else: - # TODO: handle pre_build index if it's not found but we will throw file not found - raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), args.index) + searcher = JASSv2Searcher.from_prebuilt_index(args.index) if not searcher: exit() @@ -65,6 +64,14 @@ def define_search_args(parser): # JASS does not (yet) support field-based retrieval fields = None + # JASS Parser Option + if args.ascii: + searcher.set_ascii() + if args.query: + searcher.set_query() + + + # build output path output_path = args.output if output_path is None: @@ -72,7 +79,7 @@ def define_search_args(parser): output_path = '.'.join(tokens) print(f'Running {args.topics} topics, saving to {output_path}...') - tag = output_path[:-4] if args.output is None else 'pyJass' + tag = output_path[:-4] if args.output is None else 'JaSS' output_writer = get_output_writer(output_path, OutputFormat(args.output_format), 'w', max_hits=args.hits, tag=tag, topics=topics) diff --git a/pyserini/search/jass/_searcher.py b/pyserini/search/jass/_searcher.py index 27d4bf87c..3eb853760 100644 --- a/pyserini/search/jass/_searcher.py +++ b/pyserini/search/jass/_searcher.py @@ -111,7 +111,6 @@ def convert_to_search_result(self, result_list:str) -> List[JASSv2SearcherResult return docid_score_pair - def search(self, q: str, k: int = 10, rho: int = ONE_BILLION) -> List[JASSv2SearcherResult]: """Search the collection for a single query. @@ -198,3 +197,11 @@ def batch_search(self, queries: List[str], qids: List[str], k: int = 10, rho: in output[key] = self.convert_to_search_result(results[i].results[key].results_list) return output + + def set_ascii(self) -> None: + """Set Jass to use ascii parser.""" + self.object.use_ascii_parser() + + def set_query(self) -> None: + """Set Jass to use query parser.""" + self.object.use_query_parser() \ No newline at end of file From 67cad079f8890e5fda5bb326541d944b8eeab669 Mon Sep 17 00:00:00 2001 From: Joel Mackenzie Date: Thu, 17 Feb 2022 10:29:14 +1100 Subject: [PATCH 21/35] First step for prepbuilt --- pyserini/prebuilt_index_info.py | 112 ++++++++++++++++++++++++++++++++ pyserini/util.py | 10 ++- 2 files changed, 119 insertions(+), 3 deletions(-) diff --git a/pyserini/prebuilt_index_info.py b/pyserini/prebuilt_index_info.py index 3e7f91850..47b9d42ce 100644 --- a/pyserini/prebuilt_index_info.py +++ b/pyserini/prebuilt_index_info.py @@ -1449,3 +1449,115 @@ "texts": "cast2019" } } + "msmarco-passage-deepimpact": { + "description": "Lucene impact index of the MS MARCO passage corpus encoded by DeepImpact", + "filename": "lucene-index.msmarco-passage.deepimpact.20211012.58d286.tar.gz", + "readme": "lucene-index.msmarco-passage.deepimpact.20211012.58d286.readme.txt", + "urls": [ + "https://rgw.cs.uwaterloo.ca/JIMMYLIN-bucket0/pyserini-indexes/lucene-index.msmarco-passage.deepimpact.20211012.58d286.tar.gz", + "https://vault.cs.uwaterloo.ca/s/FfwF6nB9M5sjTYk/download", + ], + "md5": "9938f5529fee5cdb405b8587746c9e93", + "size compressed (bytes)": 1295216704, + "total_terms": 35455908214, + "documents": 8841823, + "unique_terms": 3514102, + "downloaded": False + }, + "msmarco-passage-unicoil-d2q": { + "description": "Lucene impact index of the MS MARCO passage corpus encoded by uniCOIL-d2q", + "filename": "lucene-index.msmarco-passage.unicoil-d2q.20211012.58d286.tar.gz", + "readme": "lucene-index.msmarco-passage.unicoil-d2q.20211012.58d286.readme.txt", + "urls": [ + "https://rgw.cs.uwaterloo.ca/JIMMYLIN-bucket0/pyserini-indexes/lucene-index.msmarco-passage.unicoil-d2q.20211012.58d286.tar.gz", + "https://vault.cs.uwaterloo.ca/s/LGoAAXM7ZEbyQ7y/download" + ], + "md5": "4a8cb3b86a0d9085a0860c7f7bb7fe99", + "size compressed (bytes)": 1205104390, + "total_terms": 44495093768, + "documents": 8841823, + "unique_terms": 27678, + "downloaded": False + }, + +JASS_INDEX_INFO = { + "msmarco-passage-bm25": { + "description": "BP reordered JASS impact index of the MS MARCO passage corpus with BM25 scoring", + "filename": "jass-index.msmarco-passage.bm25.20220217.5cbb40.tar.gz", + "urls": [ + + ], + "md5": "c6168b6ea661e06f60e9937eb76dd01a", + "size compressed (bytes)": 710697575, + "total_terms": 0, + "documents": 0, + "unique_terms": 0, + "downloaded": False + }, + "msmarco-passage-d2q-t5": { + "description": "BP reordered JASS impact index of the MS MARCO passage corpus with BM25 scoring over a DocT5Query expanded collection", + "filename": "jass-index.msmarco-passage.d2q-t5.20220217.5cbb40.tar.gz", + "urls": [ + + ], + "md5": "b1c60fcef315890aa0c99fb71cb6aae9", + "size compressed (bytes)": 929806687, + "total_terms": 0, + "documents": 0, + "unique_terms": 0, + "downloaded": False + }, + "msmarco-passage-deepimpact": { + "description": "BP reordered JASS impact index of the MS MARCO passage corpus with DeepImpact scoring", + "filename": "jass-index.msmarco-passage.deepimpact.20220217.5cbb40.tar.gz", + "urls": [ + + ], + "md5": "d4cd22ef82d27956c9fcd32ebc0fd77b", + "size compressed (bytes)": 1217477634, + "total_terms": 0, + "documents": 0, + "unique_terms": 0, + "downloaded": False + }, + "msmarco-passage-unicoil-d2q": { + "description": "BP reordered JASS impact index of the MS MARCO passage corpus with uniCOIL scoring over a DocT5Query expanded collection", + "filename" : "jass-index.msmarco-passage.unicoil-d2q.20220217.5cbb40.tar.gz", + "urls": [ + + ], + "md5": "87d8a372dc268ad8ca259492e55d7528", + "size compressed (bytes)": 1084195359, + "total_terms": 0, + "documents": 0, + "unique_terms": 0, + "downloaded": False + }, + "msmarco-unicoil-tilde": { + "description": "BP reordered JASS impact index of the MS MARCO passage corpus with uniCOIL scoring over a TILDE expanded collection", + "filename": "jass-index.msmarco-passage.unicoil-tilde.20220217.5cbb40.tar.gz", + "urls": [ + + ], + "md5": "78033304c7b1d781b9d015a716e33ba4", + "size compressed (bytes)": 1724440877, + "total_terms": 0, + "documents": 0, + "unique_terms": 0, + "downloaded": False + }, + "msmarco-passage-distill-splade-max": { + "description": "BP reordered JASS impact index of the MS MARCO passage corpus with distill-splade-max scoring", + "filename": "jass-index.msmarco-passage.distill-splade-max.20220217.5cbb40.tar.gz", + "urls": [ + + ], + "md5": "003d1fd3a02ab35dee5e5e2949e51752", + "size compressed (bytes)": 3530600632, + "total_terms": 0, + "documents": 0, + "unique_terms": 0, + "downloaded": False + } +} + diff --git a/pyserini/util.py b/pyserini/util.py index b6b11dac5..9f42fc733 100644 --- a/pyserini/util.py +++ b/pyserini/util.py @@ -27,7 +27,7 @@ from pyserini.encoded_query_info import QUERY_INFO from pyserini.evaluate_script_info import EVALUATION_INFO -from pyserini.prebuilt_index_info import TF_INDEX_INFO, FAISS_INDEX_INFO, IMPACT_INDEX_INFO +from pyserini.prebuilt_index_info import TF_INDEX_INFO, FAISS_INDEX_INFO, IMPACT_INDEX_INFO, JASS_INDEX_INFO # https://gist.github.com/leimao/37ff6e990b3226c2c9670a2cd1e4a6f5 @@ -162,8 +162,10 @@ def check_downloaded(index_name): target_index = TF_INDEX_INFO[index_name] elif index_name in IMPACT_INDEX_INFO: target_index = IMPACT_INDEX_INFO[index_name] - else: + elif: target_index = FAISS_INDEX_INFO[index_name] + else: + target_index = JASS_INDEX_INFO[index_name] index_url = target_index['urls'][0] index_md5 = target_index['md5'] index_name = index_url.split('/')[-1] @@ -211,8 +213,10 @@ def download_prebuilt_index(index_name, force=False, verbose=True, mirror=None): target_index = TF_INDEX_INFO[index_name] elif index_name in IMPACT_INDEX_INFO: target_index = IMPACT_INDEX_INFO[index_name] - else: + elif: target_index = FAISS_INDEX_INFO[index_name] + else: + target_index = JASS_INDEX_INFO[index_name] index_md5 = target_index['md5'] for url in target_index['urls']: local_filename = target_index['filename'] if 'filename' in target_index else None From 2cffa6fb0329ffb78bd1da1f93c5a3381b355a2d Mon Sep 17 00:00:00 2001 From: Joel Mackenzie Date: Thu, 17 Feb 2022 11:00:25 +1100 Subject: [PATCH 22/35] Make default query parser the default, ascii as an option. Fix prebuilt. --- pyserini/prebuilt_index_info.py | 38 ++++--------------------------- pyserini/search/jass/__main__.py | 9 ++------ pyserini/search/jass/_searcher.py | 7 +++--- pyserini/util.py | 4 ++-- 4 files changed, 12 insertions(+), 46 deletions(-) diff --git a/pyserini/prebuilt_index_info.py b/pyserini/prebuilt_index_info.py index 47b9d42ce..ad39c66d5 100644 --- a/pyserini/prebuilt_index_info.py +++ b/pyserini/prebuilt_index_info.py @@ -1449,36 +1449,6 @@ "texts": "cast2019" } } - "msmarco-passage-deepimpact": { - "description": "Lucene impact index of the MS MARCO passage corpus encoded by DeepImpact", - "filename": "lucene-index.msmarco-passage.deepimpact.20211012.58d286.tar.gz", - "readme": "lucene-index.msmarco-passage.deepimpact.20211012.58d286.readme.txt", - "urls": [ - "https://rgw.cs.uwaterloo.ca/JIMMYLIN-bucket0/pyserini-indexes/lucene-index.msmarco-passage.deepimpact.20211012.58d286.tar.gz", - "https://vault.cs.uwaterloo.ca/s/FfwF6nB9M5sjTYk/download", - ], - "md5": "9938f5529fee5cdb405b8587746c9e93", - "size compressed (bytes)": 1295216704, - "total_terms": 35455908214, - "documents": 8841823, - "unique_terms": 3514102, - "downloaded": False - }, - "msmarco-passage-unicoil-d2q": { - "description": "Lucene impact index of the MS MARCO passage corpus encoded by uniCOIL-d2q", - "filename": "lucene-index.msmarco-passage.unicoil-d2q.20211012.58d286.tar.gz", - "readme": "lucene-index.msmarco-passage.unicoil-d2q.20211012.58d286.readme.txt", - "urls": [ - "https://rgw.cs.uwaterloo.ca/JIMMYLIN-bucket0/pyserini-indexes/lucene-index.msmarco-passage.unicoil-d2q.20211012.58d286.tar.gz", - "https://vault.cs.uwaterloo.ca/s/LGoAAXM7ZEbyQ7y/download" - ], - "md5": "4a8cb3b86a0d9085a0860c7f7bb7fe99", - "size compressed (bytes)": 1205104390, - "total_terms": 44495093768, - "documents": 8841823, - "unique_terms": 27678, - "downloaded": False - }, JASS_INDEX_INFO = { "msmarco-passage-bm25": { @@ -1487,8 +1457,8 @@ "urls": [ ], - "md5": "c6168b6ea661e06f60e9937eb76dd01a", - "size compressed (bytes)": 710697575, + "md5": "0241d6797567eec8c333187f8fa37aa3", + "size compressed (bytes)": 629101230, "total_terms": 0, "documents": 0, "unique_terms": 0, @@ -1500,8 +1470,8 @@ "urls": [ ], - "md5": "b1c60fcef315890aa0c99fb71cb6aae9", - "size compressed (bytes)": 929806687, + "md5": "7efe0e746c552b73c31869e0b0bd6837", + "size compressed (bytes)": 832303111, "total_terms": 0, "documents": 0, "unique_terms": 0, diff --git a/pyserini/search/jass/__main__.py b/pyserini/search/jass/__main__.py index e35a48dca..010990621 100644 --- a/pyserini/search/jass/__main__.py +++ b/pyserini/search/jass/__main__.py @@ -28,8 +28,7 @@ def define_search_args(parser): parser.add_argument('--index', type=str, metavar='path to index or index name', required=True, help="Path to pyJass index") parser.add_argument('--rho', type=int, default=1000000000, help='rho: how many postings to process') - parser.add_argument('--ascii', default=True, action='store_true', help="Use ASCII parser") - parser.add_argument('--query', action='store_true', help="Use Query Parser") + parser.add_argument('--ascii', default=False, action='store_true', help="Use ASCII parser") if __name__ == "__main__": parser = argparse.ArgumentParser(description='Search a pyJass index.') @@ -66,11 +65,7 @@ def define_search_args(parser): # JASS Parser Option if args.ascii: - searcher.set_ascii() - if args.query: - searcher.set_query() - - + searcher.set_ascii_parser() # build output path output_path = args.output diff --git a/pyserini/search/jass/_searcher.py b/pyserini/search/jass/_searcher.py index 3eb853760..5b96ec112 100644 --- a/pyserini/search/jass/_searcher.py +++ b/pyserini/search/jass/_searcher.py @@ -57,6 +57,7 @@ class JASSv2Searcher: def __init__(self, index_dir: str, version: int = 2): self.index_dir = index_dir self.object = pyjass.anytime() + self.set_default_parser() index = self.object.load_index(version,index_dir) if index != 0: raise Exception('Unable to load index - error code' + str(index)) @@ -198,10 +199,10 @@ def batch_search(self, queries: List[str], qids: List[str], k: int = 10, rho: in return output - def set_ascii(self) -> None: + def set_ascii_parser(self) -> None: """Set Jass to use ascii parser.""" self.object.use_ascii_parser() - def set_query(self) -> None: + def set_default_parser(self) -> None: """Set Jass to use query parser.""" - self.object.use_query_parser() \ No newline at end of file + self.object.use_query_parser() diff --git a/pyserini/util.py b/pyserini/util.py index 9f42fc733..b8a76c188 100644 --- a/pyserini/util.py +++ b/pyserini/util.py @@ -162,7 +162,7 @@ def check_downloaded(index_name): target_index = TF_INDEX_INFO[index_name] elif index_name in IMPACT_INDEX_INFO: target_index = IMPACT_INDEX_INFO[index_name] - elif: + elif index_name in FAISS_INDEX_INFO: target_index = FAISS_INDEX_INFO[index_name] else: target_index = JASS_INDEX_INFO[index_name] @@ -213,7 +213,7 @@ def download_prebuilt_index(index_name, force=False, verbose=True, mirror=None): target_index = TF_INDEX_INFO[index_name] elif index_name in IMPACT_INDEX_INFO: target_index = IMPACT_INDEX_INFO[index_name] - elif: + elif index_name in FAISS_INDEX_INFO: target_index = FAISS_INDEX_INFO[index_name] else: target_index = JASS_INDEX_INFO[index_name] From a4593ccbd7e530a19bb86ed314ed47f05219695f Mon Sep 17 00:00:00 2001 From: Pradeesh Date: Thu, 17 Feb 2022 01:10:05 +0000 Subject: [PATCH 23/35] add to export time functionality --- pyserini/search/jass/_searcher.py | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/pyserini/search/jass/_searcher.py b/pyserini/search/jass/_searcher.py index 3eb853760..35db7a6f1 100644 --- a/pyserini/search/jass/_searcher.py +++ b/pyserini/search/jass/_searcher.py @@ -57,6 +57,7 @@ class JASSv2Searcher: def __init__(self, index_dir: str, version: int = 2): self.index_dir = index_dir self.object = pyjass.anytime() + #self_getTime = None index = self.object.load_index(version,index_dir) if index != 0: raise Exception('Unable to load index - error code' + str(index)) @@ -204,4 +205,17 @@ def set_ascii(self) -> None: def set_query(self) -> None: """Set Jass to use query parser.""" - self.object.use_query_parser() \ No newline at end of file + self.object.use_query_parser() + + def __get_time_taken(self) -> float: + """Get the time taken to perform the search.' + Returns + ------- + float + Time taken to perform the search. + """ + raise NotImplementedError("This method is not implemented in JASSv2Searcher.") + + + + \ No newline at end of file From 6a052ac98cfac9428576ce8a7da9183ad32c4e16 Mon Sep 17 00:00:00 2001 From: Joel Mackenzie Date: Thu, 17 Feb 2022 14:20:11 +1100 Subject: [PATCH 24/35] Prefix 'jass' to prebuilt index strings --- pyserini/prebuilt_index_info.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/pyserini/prebuilt_index_info.py b/pyserini/prebuilt_index_info.py index ad39c66d5..d4659eed8 100644 --- a/pyserini/prebuilt_index_info.py +++ b/pyserini/prebuilt_index_info.py @@ -1451,7 +1451,7 @@ } JASS_INDEX_INFO = { - "msmarco-passage-bm25": { + "jass-msmarco-passage-bm25": { "description": "BP reordered JASS impact index of the MS MARCO passage corpus with BM25 scoring", "filename": "jass-index.msmarco-passage.bm25.20220217.5cbb40.tar.gz", "urls": [ @@ -1464,7 +1464,7 @@ "unique_terms": 0, "downloaded": False }, - "msmarco-passage-d2q-t5": { + "jass-msmarco-passage-d2q-t5": { "description": "BP reordered JASS impact index of the MS MARCO passage corpus with BM25 scoring over a DocT5Query expanded collection", "filename": "jass-index.msmarco-passage.d2q-t5.20220217.5cbb40.tar.gz", "urls": [ @@ -1477,7 +1477,7 @@ "unique_terms": 0, "downloaded": False }, - "msmarco-passage-deepimpact": { + "jass-msmarco-passage-deepimpact": { "description": "BP reordered JASS impact index of the MS MARCO passage corpus with DeepImpact scoring", "filename": "jass-index.msmarco-passage.deepimpact.20220217.5cbb40.tar.gz", "urls": [ @@ -1490,7 +1490,7 @@ "unique_terms": 0, "downloaded": False }, - "msmarco-passage-unicoil-d2q": { + "jass-msmarco-passage-unicoil-d2q": { "description": "BP reordered JASS impact index of the MS MARCO passage corpus with uniCOIL scoring over a DocT5Query expanded collection", "filename" : "jass-index.msmarco-passage.unicoil-d2q.20220217.5cbb40.tar.gz", "urls": [ @@ -1503,7 +1503,7 @@ "unique_terms": 0, "downloaded": False }, - "msmarco-unicoil-tilde": { + "jass-msmarco-unicoil-tilde": { "description": "BP reordered JASS impact index of the MS MARCO passage corpus with uniCOIL scoring over a TILDE expanded collection", "filename": "jass-index.msmarco-passage.unicoil-tilde.20220217.5cbb40.tar.gz", "urls": [ @@ -1516,7 +1516,7 @@ "unique_terms": 0, "downloaded": False }, - "msmarco-passage-distill-splade-max": { + "jass-msmarco-passage-distill-splade-max": { "description": "BP reordered JASS impact index of the MS MARCO passage corpus with distill-splade-max scoring", "filename": "jass-index.msmarco-passage.distill-splade-max.20220217.5cbb40.tar.gz", "urls": [ From 6fe29a4a862f46ce50ca6a8139eeafcaf4a48995 Mon Sep 17 00:00:00 2001 From: Joel Mackenzie Date: Thu, 17 Feb 2022 15:18:47 +1100 Subject: [PATCH 25/35] Revert some parameters to align with the Lucene API --- pyserini/search/jass/__main__.py | 11 ++++++++--- pyserini/search/jass/_searcher.py | 4 ++-- 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/pyserini/search/jass/__main__.py b/pyserini/search/jass/__main__.py index 010990621..a2dbd16a1 100644 --- a/pyserini/search/jass/__main__.py +++ b/pyserini/search/jass/__main__.py @@ -28,7 +28,7 @@ def define_search_args(parser): parser.add_argument('--index', type=str, metavar='path to index or index name', required=True, help="Path to pyJass index") parser.add_argument('--rho', type=int, default=1000000000, help='rho: how many postings to process') - parser.add_argument('--ascii', default=False, action='store_true', help="Use ASCII parser") + parser.add_argument('--basic-parser', default=False, action='store_true', help="Use the basic query parser; Default is to use the ASCII parser") if __name__ == "__main__": parser = argparse.ArgumentParser(description='Search a pyJass index.') @@ -47,6 +47,8 @@ def define_search_args(parser): default=1, help="Specify batch size to search the collection concurrently.") parser.add_argument('--threads', type=int, metavar='num', required=False, default=1, help="Maximum number of threads to use.") + parser.add_argument('--impact', action='store_true', help="Use Impact.") + args = parser.parse_args() query_iterator = get_query_iterator(args.topics, TopicsFormat(args.topics_format)) @@ -63,9 +65,12 @@ def define_search_args(parser): # JASS does not (yet) support field-based retrieval fields = None + if not args.impact: + print("Enforcing --impact; JASS requires impact-based retrieval.") + # JASS Parser Option - if args.ascii: - searcher.set_ascii_parser() + if args.basic_parser: + searcher.set_basic_parser() # build output path output_path = args.output diff --git a/pyserini/search/jass/_searcher.py b/pyserini/search/jass/_searcher.py index d71718343..b7965b301 100644 --- a/pyserini/search/jass/_searcher.py +++ b/pyserini/search/jass/_searcher.py @@ -57,7 +57,7 @@ class JASSv2Searcher: def __init__(self, index_dir: str, version: int = 2): self.index_dir = index_dir self.object = pyjass.anytime() - self.set_default_parser() + self.set_ascii_parser() index = self.object.load_index(version,index_dir) if index != 0: raise Exception('Unable to load index - error code' + str(index)) @@ -203,7 +203,7 @@ def set_ascii_parser(self) -> None: """Set Jass to use ascii parser.""" self.object.use_ascii_parser() - def set_default_parser(self) -> None: + def set_basic_parser(self) -> None: """Set Jass to use query parser.""" self.object.use_query_parser() From abc3fd99f876a53afa6f5f617bf2e040eb9b15e9 Mon Sep 17 00:00:00 2001 From: Joel Mackenzie Date: Thu, 17 Feb 2022 15:46:18 +1100 Subject: [PATCH 26/35] More fixed for pre-built. Not quite ready yet... --- pyserini/prebuilt_index_info.py | 14 +++++++------- pyserini/util.py | 2 +- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/pyserini/prebuilt_index_info.py b/pyserini/prebuilt_index_info.py index d4659eed8..24af2a6b4 100644 --- a/pyserini/prebuilt_index_info.py +++ b/pyserini/prebuilt_index_info.py @@ -1455,7 +1455,7 @@ "description": "BP reordered JASS impact index of the MS MARCO passage corpus with BM25 scoring", "filename": "jass-index.msmarco-passage.bm25.20220217.5cbb40.tar.gz", "urls": [ - + "https://rgw.cs.uwaterloo.ca/JIMMYLIN-bucket0/pyserini-indexes/jass-index.msmarco-passage.bm25.20220217.5cbb40.tar.gz" ], "md5": "0241d6797567eec8c333187f8fa37aa3", "size compressed (bytes)": 629101230, @@ -1468,7 +1468,7 @@ "description": "BP reordered JASS impact index of the MS MARCO passage corpus with BM25 scoring over a DocT5Query expanded collection", "filename": "jass-index.msmarco-passage.d2q-t5.20220217.5cbb40.tar.gz", "urls": [ - + "https://rgw.cs.uwaterloo.ca/JIMMYLIN-bucket0/pyserini-indexes/jass-index.msmarco-passage.d2q-t5.20220217.5cbb40.tar.gz" ], "md5": "7efe0e746c552b73c31869e0b0bd6837", "size compressed (bytes)": 832303111, @@ -1481,7 +1481,7 @@ "description": "BP reordered JASS impact index of the MS MARCO passage corpus with DeepImpact scoring", "filename": "jass-index.msmarco-passage.deepimpact.20220217.5cbb40.tar.gz", "urls": [ - + "https://rgw.cs.uwaterloo.ca/JIMMYLIN-bucket0/pyserini-indexes/jass-index.msmarco-passage.deepimpact.20220217.5cbb40.tar.gz" ], "md5": "d4cd22ef82d27956c9fcd32ebc0fd77b", "size compressed (bytes)": 1217477634, @@ -1494,7 +1494,7 @@ "description": "BP reordered JASS impact index of the MS MARCO passage corpus with uniCOIL scoring over a DocT5Query expanded collection", "filename" : "jass-index.msmarco-passage.unicoil-d2q.20220217.5cbb40.tar.gz", "urls": [ - + "https://rgw.cs.uwaterloo.ca/JIMMYLIN-bucket0/pyserini-indexes/jass-index.msmarco-passage.unicoil-d2q.20220217.5cbb40.tar.gz" ], "md5": "87d8a372dc268ad8ca259492e55d7528", "size compressed (bytes)": 1084195359, @@ -1507,7 +1507,7 @@ "description": "BP reordered JASS impact index of the MS MARCO passage corpus with uniCOIL scoring over a TILDE expanded collection", "filename": "jass-index.msmarco-passage.unicoil-tilde.20220217.5cbb40.tar.gz", "urls": [ - + "https://rgw.cs.uwaterloo.ca/JIMMYLIN-bucket0/pyserini-indexes/jass-index.msmarco-passage.unicoil-tilde.20220217.5cbb40.tar.gz" ], "md5": "78033304c7b1d781b9d015a716e33ba4", "size compressed (bytes)": 1724440877, @@ -1520,7 +1520,7 @@ "description": "BP reordered JASS impact index of the MS MARCO passage corpus with distill-splade-max scoring", "filename": "jass-index.msmarco-passage.distill-splade-max.20220217.5cbb40.tar.gz", "urls": [ - + "https://rgw.cs.uwaterloo.ca/JIMMYLIN-bucket0/pyserini-indexes/jass-index.msmarco-passage.distill-splade-max.20220217.5cbb40.tar.gz" ], "md5": "003d1fd3a02ab35dee5e5e2949e51752", "size compressed (bytes)": 3530600632, @@ -1530,4 +1530,4 @@ "downloaded": False } } - + diff --git a/pyserini/util.py b/pyserini/util.py index b8a76c188..88ba3f8ff 100644 --- a/pyserini/util.py +++ b/pyserini/util.py @@ -207,7 +207,7 @@ def get_dense_indexes_info(): def download_prebuilt_index(index_name, force=False, verbose=True, mirror=None): - if index_name not in TF_INDEX_INFO and index_name not in FAISS_INDEX_INFO and index_name not in IMPACT_INDEX_INFO: + if index_name not in TF_INDEX_INFO and index_name not in FAISS_INDEX_INFO and index_name not in IMPACT_INDEX_INFO and index_name not in JASS_INDEX_INFO: raise ValueError(f'Unrecognized index name {index_name}') if index_name in TF_INDEX_INFO: target_index = TF_INDEX_INFO[index_name] From 695806f2fc9ed1fe8a1f65e1934523d026d7a58f Mon Sep 17 00:00:00 2001 From: Pradeesh Date: Thu, 17 Feb 2022 20:49:07 +0000 Subject: [PATCH 27/35] first run of test_cases --- pyserini/search/jass/_searcher.py | 1 + tests/test_search_pyjass.py | 262 ++++++++++++++++++++++++++++++ 2 files changed, 263 insertions(+) create mode 100644 tests/test_search_pyjass.py diff --git a/pyserini/search/jass/_searcher.py b/pyserini/search/jass/_searcher.py index b7965b301..f448de9d2 100644 --- a/pyserini/search/jass/_searcher.py +++ b/pyserini/search/jass/_searcher.py @@ -58,6 +58,7 @@ def __init__(self, index_dir: str, version: int = 2): self.index_dir = index_dir self.object = pyjass.anytime() self.set_ascii_parser() + self.num_docs = self.object.get_document_count() index = self.object.load_index(version,index_dir) if index != 0: raise Exception('Unable to load index - error code' + str(index)) diff --git a/tests/test_search_pyjass.py b/tests/test_search_pyjass.py new file mode 100644 index 000000000..14f26a727 --- /dev/null +++ b/tests/test_search_pyjass.py @@ -0,0 +1,262 @@ +# +# Pyserini: Reproducible IR research with sparse and dense representations +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import os +import shutil +import tarfile +import unittest +from random import randint +from typing import List, Dict +from urllib.request import urlretrieve + +from pyserini.search.jass import JASSv2Searcher, JASSv2SearcherResult +import pyjass +from pyserini.index import Document + + +class TestSearchPyJass(unittest.TestCase): + def setUp(self): + # Download pre-built CACM index; append a random value to avoid filename clashes. + #TODO To-be filled in by the test runner. + r = randint(0, 10000000) + self.collection_url = 'https://github.com/castorini/anserini-data/raw/master/CACM/lucene-index.cacm.tar.gz' # to be replaced + self.tarball_name = 'lucene-index.cacm-{}.tar.gz'.format(r) + self.index_dir = 'index{}/'.format(r) + + filename, headers = urlretrieve(self.collection_url, self.tarball_name) + + tarball = tarfile.open(self.tarball_name) + tarball.extractall(self.index_dir) + tarball.close() + + self.searcher = JASSv2Searcher(f'{self.index_dir}lucene-index.cacm') + + def test_basic(self): + hits = self.searcher.search('information retrieval') + + self.assertEqual(3204, self.searcher.num_docs) + self.assertTrue(isinstance(hits, List)) + + self.assertTrue(isinstance(hits[0], JASSv2SearcherResult)) + self.assertEqual(hits[0].docid, 'CACM-3134') + self.assertAlmostEqual(hits[0].score, 4.76550, places=5) + + + + self.assertTrue(isinstance(hits[9], JASSv2SearcherResult)) + self.assertEqual(hits[9].docid, 'CACM-2516') + self.assertAlmostEqual(hits[9].score, 4.21740, places=5) + + hits = self.searcher.search('search') + + self.assertTrue(isinstance(hits[0], JASSv2SearcherResult)) + self.assertEqual(hits[0].docid, 'CACM-3058') + self.assertAlmostEqual(hits[0].score, 2.85760, places=5) + + self.assertTrue(isinstance(hits[9], JASSv2SearcherResult)) + self.assertEqual(hits[9].docid, 'CACM-3040') + self.assertAlmostEqual(hits[9].score, 2.68780, places=5) + + def test_batch(self): + results = self.searcher.batch_search(['information retrieval', 'search'], ['q1', 'q2'], threads=2) + + self.assertEqual(3204, self.searcher.num_docs) + self.assertTrue(isinstance(results, Dict)) + + self.assertTrue(isinstance(results['q1'], List)) + self.assertTrue(isinstance(results['q1'][0], JASSv2SearcherResult)) + self.assertEqual(results['q1'][0].docid, 'CACM-3134') + self.assertAlmostEqual(results['q1'][0].score, 4.76550, places=5) + + self.assertTrue(isinstance(results['q1'][9], JASSv2SearcherResult)) + self.assertEqual(results['q1'][9].docid, 'CACM-2516') + self.assertAlmostEqual(results['q1'][9].score, 4.21740, places=5) + + self.assertTrue(isinstance(results['q2'], List)) + self.assertTrue(isinstance(results['q2'][0], JASSv2SearcherResult)) + self.assertEqual(results['q2'][0].docid, 'CACM-3058') + self.assertAlmostEqual(results['q2'][0].score, 2.85760, places=5) + + self.assertTrue(isinstance(results['q2'][9], JASSv2SearcherResult)) + self.assertEqual(results['q2'][9].docid, 'CACM-3040') + self.assertAlmostEqual(results['q2'][9].score, 2.68780, places=5) + + def test_basic_k(self): + hits = self.searcher.search('information retrieval', k=100) + + self.assertEqual(3204, self.searcher.num_docs) + self.assertTrue(isinstance(hits, List)) + self.assertTrue(isinstance(hits[0], JASSv2SearcherResult)) + self.assertEqual(len(hits), 100) + + def test_batch_k(self): + results = self.searcher.batch_search(['information retrieval', 'search'], ['q1', 'q2'], k=100, threads=2) + + self.assertEqual(3204, self.searcher.num_docs) + self.assertTrue(isinstance(results, Dict)) + self.assertTrue(isinstance(results['q1'], List)) + self.assertTrue(isinstance(results['q1'][0], JASSv2SearcherResult)) + self.assertEqual(len(results['q1']), 100) + self.assertTrue(isinstance(results['q2'], List)) + self.assertTrue(isinstance(results['q2'][0], JASSv2SearcherResult)) + self.assertEqual(len(results['q2']), 100) + + def test_basic_rho(self): + # This test just provides a sanity check, it's not that interesting as it only searches one field. + hits = self.searcher.search('information retrieval', k=42, fields={'contents': 2.0}) + + self.assertEqual(3204, self.searcher.num_docs) + self.assertTrue(isinstance(hits, List)) + self.assertTrue(isinstance(hits[0], JSimpleSearcherResult)) + self.assertEqual(len(hits), 42) + + def test_batch_rho(self): + # This test just provides a sanity check, it's not that interesting as it only searches one field. + results = self.searcher.batch_search(['information retrieval', 'search'], ['q1', 'q2'], k=42, + threads=2, fields={'contents': 2.0}) + + self.assertEqual(3204, self.searcher.num_docs) + self.assertTrue(isinstance(results, Dict)) + self.assertTrue(isinstance(results['q1'], List)) + self.assertTrue(isinstance(results['q1'][0], JSimpleSearcherResult)) + self.assertEqual(len(results['q1']), 42) + self.assertTrue(isinstance(results['q2'], List)) + self.assertTrue(isinstance(results['q2'][0], JSimpleSearcherResult)) + self.assertEqual(len(results['q2']), 42) + + def test_different_similarity(self): + # qld, default mu + self.searcher.set_qld() + self.assertTrue(self.searcher.get_similarity().toString().startswith('LM Dirichlet')) + + hits = self.searcher.search('information retrieval') + + self.assertEqual(hits[0].docid, 'CACM-3134') + self.assertAlmostEqual(hits[0].score, 3.68030, places=5) + self.assertEqual(hits[9].docid, 'CACM-1927') + self.assertAlmostEqual(hits[9].score, 2.53240, places=5) + + # bm25, default parameters + self.searcher.set_bm25() + self.assertTrue(self.searcher.get_similarity().toString().startswith('BM25')) + + hits = self.searcher.search('information retrieval') + + self.assertEqual(hits[0].docid, 'CACM-3134') + self.assertAlmostEqual(hits[0].score, 4.76550, places=5) + self.assertEqual(hits[9].docid, 'CACM-2516') + self.assertAlmostEqual(hits[9].score, 4.21740, places=5) + + # qld, custom mu + self.searcher.set_qld(100) + self.assertTrue(self.searcher.get_similarity().toString().startswith('LM Dirichlet')) + + hits = self.searcher.search('information retrieval') + + self.assertEqual(hits[0].docid, 'CACM-3134') + self.assertAlmostEqual(hits[0].score, 6.35580, places=5) + self.assertEqual(hits[9].docid, 'CACM-2631') + self.assertAlmostEqual(hits[9].score, 5.18960, places=5) + + # bm25, custom parameters + self.searcher.set_bm25(0.8, 0.3) + self.assertTrue(self.searcher.get_similarity().toString().startswith('BM25')) + + hits = self.searcher.search('information retrieval') + + self.assertEqual(hits[0].docid, 'CACM-3134') + self.assertAlmostEqual(hits[0].score, 4.86880, places=5) + self.assertEqual(hits[9].docid, 'CACM-2516') + self.assertAlmostEqual(hits[9].score, 4.33320, places=5) + + def test_rm3(self): + self.searcher.set_rm3() + self.assertTrue(self.searcher.is_using_rm3()) + + hits = self.searcher.search('information retrieval') + + self.assertEqual(hits[0].docid, 'CACM-3134') + self.assertAlmostEqual(hits[0].score, 2.18010, places=5) + self.assertEqual(hits[9].docid, 'CACM-2516') + self.assertAlmostEqual(hits[9].score, 1.70330, places=5) + + self.searcher.unset_rm3() + self.assertFalse(self.searcher.is_using_rm3()) + + hits = self.searcher.search('information retrieval') + + self.assertEqual(hits[0].docid, 'CACM-3134') + self.assertAlmostEqual(hits[0].score, 4.76550, places=5) + self.assertEqual(hits[9].docid, 'CACM-2516') + self.assertAlmostEqual(hits[9].score, 4.21740, places=5) + + self.searcher.set_rm3(fb_docs=4, fb_terms=6, original_query_weight=0.3) + self.assertTrue(self.searcher.is_using_rm3()) + + hits = self.searcher.search('information retrieval') + + self.assertEqual(hits[0].docid, 'CACM-3134') + self.assertAlmostEqual(hits[0].score, 2.17190, places=5) + self.assertEqual(hits[9].docid, 'CACM-1457') + self.assertAlmostEqual(hits[9].score, 1.43700, places=5) + + def test_ascii(self): + raise NotImplementedError + + + def test_basicparser(self): + raise NotImplementedError + + def test_doc_str(self): + # The doc method is overloaded: if input is str, it's assumed to be an external collection docid. + doc = self.searcher.doc('CACM-0002') + self.assertTrue(isinstance(doc, Document)) + + # These are all equivalent ways to get the docid. + self.assertEqual(doc.lucene_document().getField('id').stringValue(), 'CACM-0002') + self.assertEqual(doc.id(), 'CACM-0002') + self.assertEqual(doc.docid(), 'CACM-0002') + self.assertEqual(doc.get('id'), 'CACM-0002') + + # These are all equivalent ways to get the 'raw' field + self.assertEqual(186, len(doc.raw())) + self.assertEqual(186, len(doc.get('raw'))) + self.assertEqual(186, len(doc.lucene_document().get('raw'))) + self.assertEqual(186, len(doc.lucene_document().getField('raw').stringValue())) + + # These are all equivalent ways to get the 'contents' field + self.assertEqual(154, len(doc.contents())) + self.assertEqual(154, len(doc.get('contents'))) + self.assertEqual(154, len(doc.lucene_document().get('contents'))) + self.assertEqual(154, len(doc.lucene_document().getField('contents').stringValue())) + + # Should return None if we request a docid that doesn't exist + self.assertTrue(self.searcher.doc('foo') is None) + + def test_doc_by_field(self): + self.assertEqual(self.searcher.doc('CACM-3134').docid(), + self.searcher.doc_by_field('id', 'CACM-3134').docid()) + + # Should return None if we request a docid that doesn't exist + self.assertTrue(self.searcher.doc_by_field('foo', 'bar') is None) + + def tearDown(self): + os.remove(self.tarball_name) + shutil.rmtree(self.index_dir) + + +if __name__ == '__main__': + unittest.main() \ No newline at end of file From 64e20a6451e14f1d2d9bc77569b222d6c0b2747b Mon Sep 17 00:00:00 2001 From: Pradeesh Date: Thu, 17 Feb 2022 21:22:50 +0000 Subject: [PATCH 28/35] updated searcher --- pyserini/search/jass/_searcher.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyserini/search/jass/_searcher.py b/pyserini/search/jass/_searcher.py index f448de9d2..d880efd5c 100644 --- a/pyserini/search/jass/_searcher.py +++ b/pyserini/search/jass/_searcher.py @@ -103,7 +103,7 @@ def convert_to_search_result(self, result_list:str) -> List[JASSv2SearcherResult List of JASSv2SearcherResult which contains the DocID and also the score pair. """ docid_score_pair = list() - results = result_list.split('\n') + results = result_list.splitlines() for res in results: # Split by space. We expect the `trec` format, bail out if we don't get it result_data = res.split(' ') From 0327cccc0c2e37ea2be10fceafc5d81bdefbd125 Mon Sep 17 00:00:00 2001 From: Pradeesh Date: Thu, 17 Feb 2022 22:51:14 +0000 Subject: [PATCH 29/35] make it go fasterrrrrr!!!! --- pyserini/search/jass/_searcher.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyserini/search/jass/_searcher.py b/pyserini/search/jass/_searcher.py index d880efd5c..67eefea76 100644 --- a/pyserini/search/jass/_searcher.py +++ b/pyserini/search/jass/_searcher.py @@ -103,7 +103,7 @@ def convert_to_search_result(self, result_list:str) -> List[JASSv2SearcherResult List of JASSv2SearcherResult which contains the DocID and also the score pair. """ docid_score_pair = list() - results = result_list.splitlines() + results = result_list.splitlines() #using split lines to split it faster for res in results: # Split by space. We expect the `trec` format, bail out if we don't get it result_data = res.split(' ') @@ -161,7 +161,7 @@ def __list_to_strvector(self, qids: List[str] ,queries: List[str]) -> pyjass.JAS c++ string_vector to be consumed by Jass. """ - return(pyjass.JASS_string_vector([str(x[0] + ":") + x[1] for x in zip(qids, queries)])) + return(pyjass.JASS_string_vector([str(x[0].join([":",x[1]])) for x in zip(qids, queries)])) From a06a640c250de463a5b9ccf9f0af99051a14f8ab Mon Sep 17 00:00:00 2001 From: Pradeesh Date: Thu, 17 Feb 2022 23:17:25 +0000 Subject: [PATCH 30/35] fix the bug --- pyserini/search/jass/_searcher.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyserini/search/jass/_searcher.py b/pyserini/search/jass/_searcher.py index 67eefea76..b8b3d1b73 100644 --- a/pyserini/search/jass/_searcher.py +++ b/pyserini/search/jass/_searcher.py @@ -103,7 +103,7 @@ def convert_to_search_result(self, result_list:str) -> List[JASSv2SearcherResult List of JASSv2SearcherResult which contains the DocID and also the score pair. """ docid_score_pair = list() - results = result_list.splitlines() #using split lines to split it faster + results = result_list.split('\n') for res in results: # Split by space. We expect the `trec` format, bail out if we don't get it result_data = res.split(' ') From 3b340a3b9a4661a70cadc97937dd63d3a94816d5 Mon Sep 17 00:00:00 2001 From: Pradeesh Date: Thu, 17 Feb 2022 23:36:58 +0000 Subject: [PATCH 31/35] using list comprehension --- pyserini/search/jass/_searcher.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyserini/search/jass/_searcher.py b/pyserini/search/jass/_searcher.py index b8b3d1b73..4ff3d50c1 100644 --- a/pyserini/search/jass/_searcher.py +++ b/pyserini/search/jass/_searcher.py @@ -161,7 +161,7 @@ def __list_to_strvector(self, qids: List[str] ,queries: List[str]) -> pyjass.JAS c++ string_vector to be consumed by Jass. """ - return(pyjass.JASS_string_vector([str(x[0].join([":",x[1]])) for x in zip(qids, queries)])) + return(pyjass.JASS_string_vector([':'.join(map(str, i)) for i in zip(qids, queries)])) From cf2453b560ecd91d1c8c25ae045f58df9d2cb043 Mon Sep 17 00:00:00 2001 From: Joel Mackenzie Date: Fri, 18 Feb 2022 15:13:39 +1100 Subject: [PATCH 32/35] Add new pre-built index hashes --- pyserini/prebuilt_index_info.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/pyserini/prebuilt_index_info.py b/pyserini/prebuilt_index_info.py index 318edc83c..e7bba38eb 100644 --- a/pyserini/prebuilt_index_info.py +++ b/pyserini/prebuilt_index_info.py @@ -1629,7 +1629,7 @@ "urls": [ "https://rgw.cs.uwaterloo.ca/JIMMYLIN-bucket0/pyserini-indexes/jass-index.msmarco-passage.bm25.20220217.5cbb40.tar.gz" ], - "md5": "0241d6797567eec8c333187f8fa37aa3", + "md5": "9add4b1f754c5f33d31501c65e5e92d3", "size compressed (bytes)": 629101230, "total_terms": 0, "documents": 0, @@ -1642,7 +1642,7 @@ "urls": [ "https://rgw.cs.uwaterloo.ca/JIMMYLIN-bucket0/pyserini-indexes/jass-index.msmarco-passage.d2q-t5.20220217.5cbb40.tar.gz" ], - "md5": "7efe0e746c552b73c31869e0b0bd6837", + "md5": "9be8d8890d60410243a8c7323849ecc9", "size compressed (bytes)": 832303111, "total_terms": 0, "documents": 0, @@ -1655,7 +1655,7 @@ "urls": [ "https://rgw.cs.uwaterloo.ca/JIMMYLIN-bucket0/pyserini-indexes/jass-index.msmarco-passage.deepimpact.20220217.5cbb40.tar.gz" ], - "md5": "d4cd22ef82d27956c9fcd32ebc0fd77b", + "md5": "d9ed05d97e1f07373d7a98a1dd9f6fac", "size compressed (bytes)": 1217477634, "total_terms": 0, "documents": 0, @@ -1668,7 +1668,7 @@ "urls": [ "https://rgw.cs.uwaterloo.ca/JIMMYLIN-bucket0/pyserini-indexes/jass-index.msmarco-passage.unicoil-d2q.20220217.5cbb40.tar.gz" ], - "md5": "87d8a372dc268ad8ca259492e55d7528", + "md5": "24bab2ef23914ab124d4f0eba8dc866c", "size compressed (bytes)": 1084195359, "total_terms": 0, "documents": 0, @@ -1681,7 +1681,7 @@ "urls": [ "https://rgw.cs.uwaterloo.ca/JIMMYLIN-bucket0/pyserini-indexes/jass-index.msmarco-passage.unicoil-tilde.20220217.5cbb40.tar.gz" ], - "md5": "78033304c7b1d781b9d015a716e33ba4", + "md5": "705c3e72cff189265de9b5c509be00a6", "size compressed (bytes)": 1724440877, "total_terms": 0, "documents": 0, @@ -1694,7 +1694,7 @@ "urls": [ "https://rgw.cs.uwaterloo.ca/JIMMYLIN-bucket0/pyserini-indexes/jass-index.msmarco-passage.distill-splade-max.20220217.5cbb40.tar.gz" ], - "md5": "003d1fd3a02ab35dee5e5e2949e51752", + "md5": "f6bf3cdf983d4e1aaee8677acbcdb47f", "size compressed (bytes)": 3530600632, "total_terms": 0, "documents": 0, From 0b780cbc1d8092aec552b363fa90976098fba275 Mon Sep 17 00:00:00 2001 From: Pradeesh Date: Fri, 18 Feb 2022 10:40:24 +0000 Subject: [PATCH 33/35] updated unit test for pyjass --- pyserini/search/jass/_searcher.py | 10 +- tests/test_search_pyjass.py | 176 +++++++----------------------- 2 files changed, 45 insertions(+), 141 deletions(-) diff --git a/pyserini/search/jass/_searcher.py b/pyserini/search/jass/_searcher.py index 4ff3d50c1..896a78ab8 100644 --- a/pyserini/search/jass/_searcher.py +++ b/pyserini/search/jass/_searcher.py @@ -58,8 +58,8 @@ def __init__(self, index_dir: str, version: int = 2): self.index_dir = index_dir self.object = pyjass.anytime() self.set_ascii_parser() - self.num_docs = self.object.get_document_count() index = self.object.load_index(version,index_dir) + self.num_docs = self.object.get_document_count() if index != 0: raise Exception('Unable to load index - error code' + str(index)) @@ -200,13 +200,13 @@ def batch_search(self, queries: List[str], qids: List[str], k: int = 10, rho: in return output - def set_ascii_parser(self) -> None: + def set_ascii_parser(self) -> int: """Set Jass to use ascii parser.""" - self.object.use_ascii_parser() + return(self.object.use_ascii_parser()) - def set_basic_parser(self) -> None: + def set_basic_parser(self) -> int: """Set Jass to use query parser.""" - self.object.use_query_parser() + return(self.object.use_query_parser()) def __get_time_taken(self) -> float: diff --git a/tests/test_search_pyjass.py b/tests/test_search_pyjass.py index 14f26a727..08b61da97 100644 --- a/tests/test_search_pyjass.py +++ b/tests/test_search_pyjass.py @@ -32,9 +32,9 @@ def setUp(self): # Download pre-built CACM index; append a random value to avoid filename clashes. #TODO To-be filled in by the test runner. r = randint(0, 10000000) - self.collection_url = 'https://github.com/castorini/anserini-data/raw/master/CACM/lucene-index.cacm.tar.gz' # to be replaced - self.tarball_name = 'lucene-index.cacm-{}.tar.gz'.format(r) - self.index_dir = 'index{}/'.format(r) + self.collection_url = 'https://github.com/prasys/anserini-data/raw/master/CACM/jass-index.cacm.tar.gz' # to be replaced + self.tarball_name = 'jass-index.cacm-{}.tar.gz'.format(r) + self.index_dir = 'jass{}/'.format(r) filename, headers = urlretrieve(self.collection_url, self.tarball_name) @@ -42,7 +42,7 @@ def setUp(self): tarball.extractall(self.index_dir) tarball.close() - self.searcher = JASSv2Searcher(f'{self.index_dir}lucene-index.cacm') + self.searcher = JASSv2Searcher(f'{self.index_dir}jass-index.cacm') def test_basic(self): hits = self.searcher.search('information retrieval') @@ -52,23 +52,23 @@ def test_basic(self): self.assertTrue(isinstance(hits[0], JASSv2SearcherResult)) self.assertEqual(hits[0].docid, 'CACM-3134') - self.assertAlmostEqual(hits[0].score, 4.76550, places=5) + self.assertEqual(hits[0].score, 664.0) self.assertTrue(isinstance(hits[9], JASSv2SearcherResult)) - self.assertEqual(hits[9].docid, 'CACM-2516') - self.assertAlmostEqual(hits[9].score, 4.21740, places=5) + self.assertEqual(hits[9].docid, 'CACM-2631') + self.assertEqual(hits[9].score, 589.0) hits = self.searcher.search('search') self.assertTrue(isinstance(hits[0], JASSv2SearcherResult)) - self.assertEqual(hits[0].docid, 'CACM-3058') - self.assertAlmostEqual(hits[0].score, 2.85760, places=5) + self.assertEqual(hits[0].docid, 'CACM-3041') + self.assertEqual(hits[0].score, 413.0) self.assertTrue(isinstance(hits[9], JASSv2SearcherResult)) - self.assertEqual(hits[9].docid, 'CACM-3040') - self.assertAlmostEqual(hits[9].score, 2.68780, places=5) + self.assertEqual(hits[9].docid, 'CACM-1815') + self.assertEqual(hits[9].score, 392.0) def test_batch(self): results = self.searcher.batch_search(['information retrieval', 'search'], ['q1', 'q2'], threads=2) @@ -79,20 +79,20 @@ def test_batch(self): self.assertTrue(isinstance(results['q1'], List)) self.assertTrue(isinstance(results['q1'][0], JASSv2SearcherResult)) self.assertEqual(results['q1'][0].docid, 'CACM-3134') - self.assertAlmostEqual(results['q1'][0].score, 4.76550, places=5) + self.assertEqual(results['q1'][0].score, 664.0) self.assertTrue(isinstance(results['q1'][9], JASSv2SearcherResult)) - self.assertEqual(results['q1'][9].docid, 'CACM-2516') - self.assertAlmostEqual(results['q1'][9].score, 4.21740, places=5) + self.assertEqual(results['q1'][9].docid, 'CACM-2631') + self.assertEqual(results['q1'][9].score, 589.0) self.assertTrue(isinstance(results['q2'], List)) self.assertTrue(isinstance(results['q2'][0], JASSv2SearcherResult)) - self.assertEqual(results['q2'][0].docid, 'CACM-3058') - self.assertAlmostEqual(results['q2'][0].score, 2.85760, places=5) + self.assertEqual(results['q2'][0].docid, 'CACM-3041') + self.assertEqual(results['q2'][0].score, 413.0) self.assertTrue(isinstance(results['q2'][9], JASSv2SearcherResult)) - self.assertEqual(results['q2'][9].docid, 'CACM-3040') - self.assertAlmostEqual(results['q2'][9].score, 2.68780, places=5) + self.assertEqual(results['q2'][9].docid, 'CACM-1815') + self.assertEqual(results['q2'][9].score, 392.0) def test_basic_k(self): hits = self.searcher.search('information retrieval', k=100) @@ -112,146 +112,50 @@ def test_batch_k(self): self.assertEqual(len(results['q1']), 100) self.assertTrue(isinstance(results['q2'], List)) self.assertTrue(isinstance(results['q2'][0], JASSv2SearcherResult)) - self.assertEqual(len(results['q2']), 100) + self.assertEqual(len(results['q2']), 99) def test_basic_rho(self): - # This test just provides a sanity check, it's not that interesting as it only searches one field. - hits = self.searcher.search('information retrieval', k=42, fields={'contents': 2.0}) + hits = self.searcher.search('information retrieval', k=42, rho=50) self.assertEqual(3204, self.searcher.num_docs) self.assertTrue(isinstance(hits, List)) - self.assertTrue(isinstance(hits[0], JSimpleSearcherResult)) + self.assertTrue(isinstance(hits[0], JASSv2SearcherResult)) + self.assertEqual(hits[9].docid, 'CACM-1725') + self.assertEqual(hits[9].score, 362.0) self.assertEqual(len(hits), 42) def test_batch_rho(self): # This test just provides a sanity check, it's not that interesting as it only searches one field. results = self.searcher.batch_search(['information retrieval', 'search'], ['q1', 'q2'], k=42, - threads=2, fields={'contents': 2.0}) + threads=2, rho=50) self.assertEqual(3204, self.searcher.num_docs) self.assertTrue(isinstance(results, Dict)) self.assertTrue(isinstance(results['q1'], List)) - self.assertTrue(isinstance(results['q1'][0], JSimpleSearcherResult)) + self.assertTrue(isinstance(results['q1'][0], JASSv2SearcherResult)) self.assertEqual(len(results['q1']), 42) + self.assertEqual(results['q1'][9].docid, 'CACM-1725') + self.assertEqual(results['q1'][9].score, 362.0) + self.assertTrue(isinstance(results['q2'], List)) - self.assertTrue(isinstance(results['q2'][0], JSimpleSearcherResult)) + self.assertTrue(isinstance(results['q2'][0], JASSv2SearcherResult)) self.assertEqual(len(results['q2']), 42) + self.assertEqual(results['q2'][9].docid, 'CACM-1815') + self.assertEqual(results['q2'][9].score, 392.0) - def test_different_similarity(self): - # qld, default mu - self.searcher.set_qld() - self.assertTrue(self.searcher.get_similarity().toString().startswith('LM Dirichlet')) - - hits = self.searcher.search('information retrieval') - - self.assertEqual(hits[0].docid, 'CACM-3134') - self.assertAlmostEqual(hits[0].score, 3.68030, places=5) - self.assertEqual(hits[9].docid, 'CACM-1927') - self.assertAlmostEqual(hits[9].score, 2.53240, places=5) - - # bm25, default parameters - self.searcher.set_bm25() - self.assertTrue(self.searcher.get_similarity().toString().startswith('BM25')) - - hits = self.searcher.search('information retrieval') - - self.assertEqual(hits[0].docid, 'CACM-3134') - self.assertAlmostEqual(hits[0].score, 4.76550, places=5) - self.assertEqual(hits[9].docid, 'CACM-2516') - self.assertAlmostEqual(hits[9].score, 4.21740, places=5) - - # qld, custom mu - self.searcher.set_qld(100) - self.assertTrue(self.searcher.get_similarity().toString().startswith('LM Dirichlet')) - - hits = self.searcher.search('information retrieval') - - self.assertEqual(hits[0].docid, 'CACM-3134') - self.assertAlmostEqual(hits[0].score, 6.35580, places=5) - self.assertEqual(hits[9].docid, 'CACM-2631') - self.assertAlmostEqual(hits[9].score, 5.18960, places=5) - - # bm25, custom parameters - self.searcher.set_bm25(0.8, 0.3) - self.assertTrue(self.searcher.get_similarity().toString().startswith('BM25')) - - hits = self.searcher.search('information retrieval') - - self.assertEqual(hits[0].docid, 'CACM-3134') - self.assertAlmostEqual(hits[0].score, 4.86880, places=5) - self.assertEqual(hits[9].docid, 'CACM-2516') - self.assertAlmostEqual(hits[9].score, 4.33320, places=5) - - def test_rm3(self): - self.searcher.set_rm3() - self.assertTrue(self.searcher.is_using_rm3()) - - hits = self.searcher.search('information retrieval') - - self.assertEqual(hits[0].docid, 'CACM-3134') - self.assertAlmostEqual(hits[0].score, 2.18010, places=5) - self.assertEqual(hits[9].docid, 'CACM-2516') - self.assertAlmostEqual(hits[9].score, 1.70330, places=5) - - self.searcher.unset_rm3() - self.assertFalse(self.searcher.is_using_rm3()) - - hits = self.searcher.search('information retrieval') - - self.assertEqual(hits[0].docid, 'CACM-3134') - self.assertAlmostEqual(hits[0].score, 4.76550, places=5) - self.assertEqual(hits[9].docid, 'CACM-2516') - self.assertAlmostEqual(hits[9].score, 4.21740, places=5) - - self.searcher.set_rm3(fb_docs=4, fb_terms=6, original_query_weight=0.3) - self.assertTrue(self.searcher.is_using_rm3()) - - hits = self.searcher.search('information retrieval') - - self.assertEqual(hits[0].docid, 'CACM-3134') - self.assertAlmostEqual(hits[0].score, 2.17190, places=5) - self.assertEqual(hits[9].docid, 'CACM-1457') - self.assertAlmostEqual(hits[9].score, 1.43700, places=5) + # def test_different_similarity(self): def test_ascii(self): - raise NotImplementedError + output = self.searcher.set_ascii_parser() + self.assertEqual(0, output) + - def test_basicparser(self): - raise NotImplementedError - - def test_doc_str(self): - # The doc method is overloaded: if input is str, it's assumed to be an external collection docid. - doc = self.searcher.doc('CACM-0002') - self.assertTrue(isinstance(doc, Document)) - - # These are all equivalent ways to get the docid. - self.assertEqual(doc.lucene_document().getField('id').stringValue(), 'CACM-0002') - self.assertEqual(doc.id(), 'CACM-0002') - self.assertEqual(doc.docid(), 'CACM-0002') - self.assertEqual(doc.get('id'), 'CACM-0002') - - # These are all equivalent ways to get the 'raw' field - self.assertEqual(186, len(doc.raw())) - self.assertEqual(186, len(doc.get('raw'))) - self.assertEqual(186, len(doc.lucene_document().get('raw'))) - self.assertEqual(186, len(doc.lucene_document().getField('raw').stringValue())) - - # These are all equivalent ways to get the 'contents' field - self.assertEqual(154, len(doc.contents())) - self.assertEqual(154, len(doc.get('contents'))) - self.assertEqual(154, len(doc.lucene_document().get('contents'))) - self.assertEqual(154, len(doc.lucene_document().getField('contents').stringValue())) - - # Should return None if we request a docid that doesn't exist - self.assertTrue(self.searcher.doc('foo') is None) - - def test_doc_by_field(self): - self.assertEqual(self.searcher.doc('CACM-3134').docid(), - self.searcher.doc_by_field('id', 'CACM-3134').docid()) - - # Should return None if we request a docid that doesn't exist - self.assertTrue(self.searcher.doc_by_field('foo', 'bar') is None) + def test_basic_parser(self): + output = self.searcher.set_basic_parser() + self.assertEqual(0, output) + + def tearDown(self): os.remove(self.tarball_name) From ca206d40f58e37a847635b4e271ff532706f804a Mon Sep 17 00:00:00 2001 From: Pradeesh Date: Fri, 18 Feb 2022 10:52:12 +0000 Subject: [PATCH 34/35] update pip requirements --- requirements.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/requirements.txt b/requirements.txt index d3927eb67..4c37a72c6 100644 --- a/requirements.txt +++ b/requirements.txt @@ -11,3 +11,4 @@ nmslib>=2.1.1 onnxruntime>=1.8.1 lightgbm>=3.3.2 spacy>=3.2.1 +pyjass>=0.2a7 From 8c64a014f341c65a40a1fd7a391cc744b0133818 Mon Sep 17 00:00:00 2001 From: Pradeesh Date: Sun, 20 Feb 2022 21:52:33 +0000 Subject: [PATCH 35/35] changed it to 88 cause its lucky --- tests/test_search_pyjass.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/test_search_pyjass.py b/tests/test_search_pyjass.py index 08b61da97..c3638b240 100644 --- a/tests/test_search_pyjass.py +++ b/tests/test_search_pyjass.py @@ -95,24 +95,24 @@ def test_batch(self): self.assertEqual(results['q2'][9].score, 392.0) def test_basic_k(self): - hits = self.searcher.search('information retrieval', k=100) + hits = self.searcher.search('information retrieval', k=88) self.assertEqual(3204, self.searcher.num_docs) self.assertTrue(isinstance(hits, List)) self.assertTrue(isinstance(hits[0], JASSv2SearcherResult)) - self.assertEqual(len(hits), 100) + self.assertEqual(len(hits), 88) def test_batch_k(self): - results = self.searcher.batch_search(['information retrieval', 'search'], ['q1', 'q2'], k=100, threads=2) + results = self.searcher.batch_search(['information retrieval', 'search'], ['q1', 'q2'], k=88, threads=2) self.assertEqual(3204, self.searcher.num_docs) self.assertTrue(isinstance(results, Dict)) self.assertTrue(isinstance(results['q1'], List)) self.assertTrue(isinstance(results['q1'][0], JASSv2SearcherResult)) - self.assertEqual(len(results['q1']), 100) + self.assertEqual(len(results['q1']), 88) self.assertTrue(isinstance(results['q2'], List)) self.assertTrue(isinstance(results['q2'][0], JASSv2SearcherResult)) - self.assertEqual(len(results['q2']), 99) + self.assertEqual(len(results['q2']), 88) def test_basic_rho(self): hits = self.searcher.search('information retrieval', k=42, rho=50)