diff --git a/open-simulation-interface b/open-simulation-interface
index 296c549..560d23a 160000
--- a/open-simulation-interface
+++ b/open-simulation-interface
@@ -1 +1 @@
-Subproject commit 296c549e95364f360beb724e39b2c6470d94b780
+Subproject commit 560d23a6e842bb8a71bbb9fcf52c168e75b5e38d
diff --git a/osivalidator/osi_general_validator.py b/osivalidator/osi_general_validator.py
index 675a94c..a79c165 100755
--- a/osivalidator/osi_general_validator.py
+++ b/osivalidator/osi_general_validator.py
@@ -5,7 +5,8 @@ import argparse
 from multiprocessing import Pool, Manager
 from tqdm import tqdm
 
-import os, sys
+import os
+import sys
 
 sys.path.append(os.path.join(os.path.dirname(__file__), "."))
 
@@ -14,13 +15,24 @@
     import osi_rules
     import osi_validator_logger
     import osi_rules_checker
-    import osi_trace
+    import linked_proto_field
+    from format.OSITrace import OSITrace
 except Exception as e:
     print(
-        "Make sure you have installed the requirements with 'pip install -r requirements.txt'!"
+        "Make sure you have installed the requirements with 'python3 -m pip install -r requirements.txt'!"
     )
     print(e)
 
+# Global variables
+manager_ = Manager()
+logs_ = manager_.list()
+timestamp_analyzed_ = manager_.list()
+logger_ = osi_validator_logger.OSIValidatorLogger()
+validation_rules_ = osi_rules.OSIRules()
+id_to_ts_ = {}
+bar_suffix_ = "%(index)d/%(max)d [%(elapsed_td)s]"
+message_cache_ = {}
+
 
 def check_positive_int(value):
     ivalue = int(value)
@@ -108,7 +120,8 @@ def command_line_arguments():
     parser.add_argument(
         "--buffer",
         "-bu",
-        help="Set the buffer size to retrieve OSI messages from trace file. Set it to 0 if you do not want to use buffering at all.",
+        help="Set the buffer size to retrieve OSI messages from trace file. "
+        "Set it to 0 if you do not want to use buffering at all.",
        default=1000000,
        type=check_positive_int,
        required=False,
@@ -117,16 +130,6 @@ def command_line_arguments():
     return parser.parse_args()
 
 
-MANAGER = Manager()
-LOGS = MANAGER.list()
-TIMESTAMP_ANALYZED = MANAGER.list()
-LOGGER = osi_validator_logger.OSIValidatorLogger()
-VALIDATION_RULES = osi_rules.OSIRules()
-ID_TO_TS = {}
-BAR_SUFFIX = "%(index)d/%(max)d [%(elapsed_td)s]"
-MESSAGE_CACHE = {}
-
-
 def main():
     """Main method"""
 
@@ -139,34 +142,34 @@ def main():
     if not os.path.exists(directory):
         os.makedirs(directory)
 
-    LOGGER.init(args.debug, args.verbose, directory)
+    logger_.init(args.debug, args.verbose, directory)
 
     # Read data
     print("Reading data ...")
-    DATA = osi_trace.OSITrace(buffer_size=args.buffer)
-    DATA.from_file(path=args.data, type_name=args.type, max_index=args.timesteps)
+    trace_data = OSITrace(buffer_size=args.buffer)
+    trace_data.from_file(path=args.data, type_name=args.type, max_index=args.timesteps)
 
-    if DATA.timestep_count < args.timesteps:
+    if trace_data.timestep_count < args.timesteps:
         args.timesteps = -1
 
     # Collect Validation Rules
     print("Collect validation rules ...")
-    VALIDATION_RULES.from_yaml_directory(args.rules)
+    validation_rules_.from_yaml_directory(args.rules)
 
     # Pass all timesteps or the number specified
     if args.timesteps != -1:
         max_timestep = args.timesteps
-        LOGGER.info(None, f"Pass the {max_timestep} first timesteps")
+        logger_.info(None, f"Pass the {max_timestep} first timesteps")
     else:
-        LOGGER.info(None, "Pass all timesteps")
-        max_timestep = DATA.timestep_count
+        logger_.info(None, "Pass all timesteps")
+        max_timestep = trace_data.timestep_count
 
     # Dividing in several blast to not overload the memory
     max_timestep_blast = 0
 
     while max_timestep_blast < max_timestep:
         # Clear log queue
-        LOGS = MANAGER.list()
+        logs_ = manager_.list()
 
         # Increment the max-timestep to analyze
         max_timestep_blast += args.blast
@@ -174,8 +177,8 @@ def main():
         last_of_blast = min(max_timestep_blast, max_timestep)
 
         # Cache messages
-        DATA.cache_messages_in_index_range(first_of_blast, last_of_blast)
-        MESSAGE_CACHE.update(DATA.message_cache)
+        trace_data.cache_messages_in_index_range(first_of_blast, last_of_blast)
+        message_cache_.update(trace_data.message_cache)
 
         if args.parallel:
             # Launch parallel computation
@@ -202,9 +205,9 @@ def main():
         except Exception as e:
             print(str(e))
 
-        MESSAGE_CACHE.clear()
+        message_cache_.clear()
 
-    DATA.trace_file.close()
+    trace_data.trace_file.close()
 
     display_results()
 
@@ -217,67 +220,69 @@ def close_pool(pool):
 
 def process_timestep(timestep, data_type):
     """Process one timestep"""
-    message = MESSAGE_CACHE[timestep]
-    rule_checker = osi_rules_checker.OSIRulesChecker(LOGGER)
+    message = linked_proto_field.LinkedProtoField(
+        message_cache_[timestep], name=data_type
+    )
+    rule_checker = osi_rules_checker.OSIRulesChecker(logger_)
     timestamp = rule_checker.set_timestamp(message.value.timestamp, timestep)
-    ID_TO_TS[timestep] = timestamp
+    id_to_ts_[timestep] = timestamp
 
-    LOGGER.log_messages[timestep] = []
-    LOGGER.debug_messages[timestep] = []
-    LOGGER.info(None, f"Analyze message of timestamp {timestamp}", False)
+    logger_.log_messages[timestep] = []
+    logger_.debug_messages[timestep] = []
+    logger_.info(None, f"Analyze message of timestamp {timestamp}", False)
 
-    with MANAGER.Lock():
-        if timestamp in TIMESTAMP_ANALYZED:
-            LOGGER.error(timestep, f"Timestamp already exists")
-        TIMESTAMP_ANALYZED.append(timestamp)
+    with manager_.Lock():
+        if timestamp in timestamp_analyzed_:
+            logger_.error(timestep, f"Timestamp already exists")
+        timestamp_analyzed_.append(timestamp)
 
     # Check common rules
     getattr(rule_checker, "is_valid")(
-        message, VALIDATION_RULES.get_rules().get_type(data_type)
+        message, validation_rules_.get_rules().get_type(data_type)
     )
 
-    LOGS.extend(LOGGER.log_messages[timestep])
+    logs_.extend(logger_.log_messages[timestep])
 
 
 def get_message_count(data, data_type="SensorView", from_message=0, to_message=None):
     # Wrapper function for external use in combination with process_timestep
-    timesteps = None
+    time_steps = None
 
     if from_message != 0:
         print("Currently only validation from the first frame (0) is supported!")
 
     if to_message is not None:
-        timesteps = int(to_message)
+        time_steps = int(to_message)
 
     # Read data
     print("Reading data ...")
-    DATA = osi_trace.OSITrace(buffer_size=1000000)
-    DATA.from_file(path=data, type_name=data_type, max_index=timesteps)
+    trace_data = OSITrace(buffer_size=1000000)
+    trace_data.from_file(path=data, type_name=data_type, max_index=time_steps)
 
-    if DATA.timestep_count < timesteps:
-        timesteps = -1
+    if trace_data.timestep_count < time_steps:
+        time_steps = -1
 
     # Collect Validation Rules
     print("Collect validation rules ...")
     try:
-        VALIDATION_RULES.from_yaml_directory("osi-validation/rules/")
+        validation_rules_.from_yaml_directory("osi-validation/rules/")
     except Exception as e:
         print("Error collecting validation rules:", e)
 
-    # Pass all timesteps or the number specified
-    if timesteps != -1:
-        max_timestep = timesteps
-        LOGGER.info(None, f"Pass the {max_timestep} first timesteps")
+    # Pass all time_steps or the number specified
+    if time_steps != -1:
+        max_timestep = time_steps
+        logger_.info(None, f"Pass the {max_timestep} first time_steps")
     else:
-        LOGGER.info(None, "Pass all timesteps")
-        max_timestep = DATA.timestep_count
+        logger_.info(None, "Pass all time_steps")
+        max_timestep = trace_data.timestep_count
 
     # Dividing in several blast to not overload the memory
     max_timestep_blast = 0
 
     while max_timestep_blast < max_timestep:
         # Clear log queue
-        LOGS[:] = []
+        logs_[:] = []
 
         # Increment the max-timestep to analyze
         max_timestep_blast += 500
@@ -285,17 +290,17 @@ def get_message_count(data, data_type="SensorView", from_message=0, to_message=N
         last_of_blast = min(max_timestep_blast, max_timestep)
 
         # Cache messages
-        DATA.cache_messages_in_index_range(first_of_blast, last_of_blast)
-        MESSAGE_CACHE.update(DATA.message_cache)
+        trace_data.cache_messages_in_index_range(first_of_blast, last_of_blast)
+        message_cache_.update(trace_data.message_cache)
 
-    DATA.trace_file.close()
+    trace_data.trace_file.close()
 
-    return len(MESSAGE_CACHE)
+    return len(message_cache_)
 
 
-# Synthetize Logs
+# Synthesize Logs
 def display_results():
-    return LOGGER.synthetize_results(LOGS)
+    return logger_.synthetize_results(logs_)
 
 
 if __name__ == "__main__":
diff --git a/osivalidator/osi_trace.py b/osivalidator/osi_trace.py
deleted file mode 100644
index 990f833..0000000
--- a/osivalidator/osi_trace.py
+++ /dev/null
@@ -1,350 +0,0 @@
-"""
-Module that contains OSIDataContainer class to handle and manage OSI traces.
-"""
-from collections import deque
-import time
-import struct
-
-from osi3.osi_sensorview_pb2 import SensorView
-from osi3.osi_groundtruth_pb2 import GroundTruth
-from osi3.osi_sensordata_pb2 import SensorData
-from tqdm import tqdm
-
-import warnings
-
-warnings.simplefilter("default")
-import os, sys
-
-sys.path.append(os.path.join(os.path.dirname(__file__), "."))
-import linked_proto_field
-
-SEPARATOR = b"$$__$$"
-SEPARATOR_LENGTH = len(SEPARATOR)
-
-
-def get_size_from_file_stream(file_object):
-    """
-    Return a file size from a file stream given in parameters
-    """
-    current_position = file_object.tell()
-    file_object.seek(0, 2)
-    size = file_object.tell()
-    file_object.seek(current_position)
-    return size
-
-
-MESSAGES_TYPE = {
-    "SensorView": SensorView,
-    "GroundTruth": GroundTruth,
-    "SensorData": SensorData,
-}
-
-
-class OSITrace:
-    """This class wrap OSI data. It can import and decode OSI traces."""
-
-    def __init__(self, buffer_size, show_progress=True, type_name="SensorView"):
-        self.trace_file = None
-        self.message_offsets = None
-        self.buffer_size = buffer_size
-        self._int_length = len(struct.pack("<L", 0))
[...]
-                    > self.buffer_size * (counter + 1)
-
-                # Check if reached end of file
-                if self.trace_file.tell() == trace_size:
-                    self.retrieved_trace_size = self.message_offsets[-1]
-                    self.message_offsets.pop()  # Remove the last element since after that there is no message coming
-                    break
-
-            while eof:
-                # Counter increment and cursor placement update. The cursor is set absolute in the file.
-                if message_offset >= len(serialized_message):
-                    progress_bar.update(message_offset - last_offset)
-                    last_offset = message_offset
-                    counter += 1
-                    self.trace_file.seek(counter * self.buffer_size)
-                    eof = False
-
-                else:
-                    serialized_message = self.trace_file.read()
-                    while message_offset < trace_size:
-                        message_length = struct.unpack(
-                            "<L",
-                            serialized_message[
-                                message_offset : message_offset + self._int_length
-                            ],
-                        )[0]
-                        message_offset += message_length + self._int_length
-                        if message_offset >= trace_size:
-                            break
-                        self.message_offsets.append(message_offset)
-                        progress_bar.update(message_offset - last_offset)
-                        last_offset = message_offset
-
-        if eof:
-            self.retrieved_trace_size = trace_size
-        else:
-            self.retrieved_trace_size = self.message_offsets[-1]
-            self.message_offsets.pop()
-
-        if self.show_progress:
-            progress_bar.close()
-            print(
-                len(self.message_offsets),
-                "messages has been discovered in",
-                time.time() - start_time,
-                "s",
-            )
-
-        return len(self.message_offsets)
-
-    def get_message_by_index(self, index):
-        """
-        Get a message by its index. Try first to get it from the cache made
-        by the method ``cache_messages_in_index_range``.
-        """
-        message = self.message_cache.get(index, None)
-
-        if message is not None:
-            return message
-
-        message = next(self.get_messages_in_index_range(index, index + 1))
-        return linked_proto_field.LinkedProtoField(message, name=self.type_name)
-
-    def get_messages_in_index_range(self, begin, end):
-        """
-        Yield an iterator over messages of indexes between begin and end included.
-        """
-
-        self.trace_file.seek(self.message_offsets[begin])
-        abs_first_offset = self.message_offsets[begin]
-        abs_last_offset = (
-            self.message_offsets[end]
-            if end < len(self.message_offsets)
-            else self.retrieved_trace_size
-        )
-
-        rel_message_offsets = [
-            abs_message_offset - abs_first_offset
-            for abs_message_offset in self.message_offsets[begin:end]
-        ]
-
-        if self.path.lower().endswith((".txt")):
-            message_sequence_len = abs_last_offset - abs_first_offset - SEPARATOR_LENGTH
-            serialized_messages_extract = self.trace_file.read(message_sequence_len)
-
-            pbar = tqdm(rel_message_offsets)
-            for rel_index, rel_message_offset in enumerate(pbar):
-                pbar.set_description(
-                    f"Processing index {rel_index} with offset {rel_message_offset}"
-                )
-                rel_begin = rel_message_offset
-                rel_end = (
-                    rel_message_offsets[rel_index + 1] - SEPARATOR_LENGTH
-                    if rel_index + 1 < len(rel_message_offsets)
-                    else message_sequence_len
-                )
-
-                message = MESSAGES_TYPE[self.type_name]()
-                serialized_message = serialized_messages_extract[rel_begin:rel_end]
-                message.ParseFromString(serialized_message)
-                yield linked_proto_field.LinkedProtoField(message, name=self.type_name)
-
-        elif self.path.lower().endswith((".osi")):
-            message_sequence_len = abs_last_offset - abs_first_offset
-            serialized_messages_extract = self.trace_file.read(message_sequence_len)
-            message_length = 0
-            i = 0
-            while i < len(serialized_messages_extract):
-                message = MESSAGES_TYPE[self.type_name]()
-                message_length = struct.unpack(
-                    "