diff --git a/chattool/__init__.py b/chattool/__init__.py index ab56492..29e94bb 100644 --- a/chattool/__init__.py +++ b/chattool/__init__.py @@ -2,7 +2,7 @@ __author__ = """Rex Wang""" __email__ = '1073853456@qq.com' -__version__ = '3.3.1' +__version__ = '3.3.2' import os, sys, requests, json from .chattype import Chat, Resp diff --git a/chattool/checkpoint.py b/chattool/checkpoint.py index d93c34c..274326f 100644 --- a/chattool/checkpoint.py +++ b/chattool/checkpoint.py @@ -1,7 +1,8 @@ -import json, warnings, os +import json, os from typing import List, Dict, Union, Callable, Any from .chattype import Chat import tqdm +from loguru import logger def load_chats( checkpoint:str): """Load chats from a checkpoint file @@ -23,18 +24,16 @@ def load_chats( checkpoint:str): if len(txts) == 1 and txts[0] == '': return [] # get the chatlogs logs = [json.loads(txt) for txt in txts] - chat_size, chatlogs = 1, [None] - for log in logs: - idx = log['index'] - if idx >= chat_size: # extend chatlogs - chatlogs.extend([None] * (idx - chat_size + 1)) - chat_size = idx + 1 - chatlogs[idx] = log['chat_log'] + # mapping from index to chat object + idx2chatlog = { log['index']: Chat(log['chat_log']) for log in logs } + max_index = max(idx2chatlog.keys()) + chat_objects = [ idx2chatlog.get(index, None) for index in range(max_index+1)] + num_unfinished = chat_objects.count(None) # check if there are missing chatlogs - if None in chatlogs: - warnings.warn(f"checkpoint file {checkpoint} has unfinished chats") + if num_unfinished > 0: + logger.warning(f"checkpoint file {checkpoint} has {num_unfinished}/{max_index+1} unfinished chats") # return Chat class - return [Chat(chat_log) if chat_log is not None else None for chat_log in chatlogs] + return chat_objects def process_chats( data:List[Any] , data2chat:Callable[[Any], Chat] @@ -59,7 +58,7 @@ def process_chats( data:List[Any] ## load chats from the checkpoint file chats = load_chats(checkpoint) if len(chats) > len(data): - warnings.warn(f"checkpoint file {checkpoint} has more chats than the data to be processed") + logger.warning(f"checkpoint file {checkpoint} has more chats than the data to be processed") return chats[:len(data)] chats.extend([None] * (len(data) - len(chats))) ## process chats @@ -69,4 +68,4 @@ def process_chats( data:List[Any] chat = data2chat(data[i]) chat.save(checkpoint, mode='a', index=i) chats[i] = chat - return chats \ No newline at end of file + return chats diff --git a/setup.py b/setup.py index 3ed4b55..5a62168 100644 --- a/setup.py +++ b/setup.py @@ -7,7 +7,7 @@ with open('README.md') as readme_file: readme = readme_file.read() -VERSION = '3.3.1' +VERSION = '3.3.2' requirements = [ 'Click>=7.0', 'requests>=2.20', "responses>=0.23", 'aiohttp>=3.8',