diff --git a/testcases/loading/test_loading.py b/testcases/loading/test_loading.py index e3ed207..fdc52eb 100644 --- a/testcases/loading/test_loading.py +++ b/testcases/loading/test_loading.py @@ -35,9 +35,9 @@ class SimpleLoadTest: """A helper class to generate a simple load on a SMB server""" - instance_num = 0 - max_files = 10 - test_string = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz" + instance_num: int = 0 + max_files: int = 10 + test_string: str = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz" def __init__( self, @@ -46,10 +46,12 @@ def __init__( username: str, passwd: str, testdir: str, + testfile: str = "", ): self.idnum: int = type(self).instance_num type(self).instance_num += 1 + self.testfile = testfile self.rootpath: str = f"{testdir}/test{self.idnum}" self.files: typing.List[str] = [] self.thread = None @@ -119,20 +121,31 @@ def _simple_run(self, op=""): self._simple_run(op="write") return self.stats["read"] += 1 - self.smbclient.read_text(file) + if self.testfile: + tfile = testhelper.get_tmp_file() + with open(tfile, "wb") as fd: + self.smbclient.read(file, fd) + os.unlink(tfile) + else: + self.smbclient.read_text(file) elif op == "write": file = self._new_file() if not file: return self.stats["write"] += 1 - self.smbclient.write_text(file, type(self).test_string) + if self.testfile: + with open(self.testfile, "rb") as fd: + self.smbclient.write(file, fd) + else: + self.smbclient.write_text(file, type(self).test_string) elif op == "delete": file = self._del_file() if not file: return self.stats["delete"] += 1 self.smbclient.unlink(file) - except IOError as error: + # Catch all errors + except Exception as error: print(error) self.stats["error"] += 1 @@ -184,12 +197,15 @@ def __init__( username: str, passwd: str, testdir: str, + testfile: str = "", ): self.server: str = hostname self.share: str = share self.username: str = username self.password: str = passwd self.testdir: str = testdir + self.testfile = testfile + self.connections: typing.List[SimpleLoadTest] = [] self.start_time: float = 0 self.stop_time: float = 0 @@ -207,6 +223,7 @@ def set_connection_num(self, num: int) -> None: self.username, self.password, self.testdir, + self.testfile, ) self.connections.append(smbclient) elif cnum > num: @@ -263,6 +280,7 @@ def start_process( ret_queue: Queue, mount_params: typing.Dict[str, str], testdir: str, + testfile: str = "", ) -> None: """Start function for test processes""" loadtest: LoadTest = LoadTest( @@ -271,6 +289,7 @@ def start_process( mount_params["username"], mount_params["password"], testdir, + testfile, ) loadtest.set_connection_num(numcons) loadtest.start_tests(test_runtime) @@ -293,6 +312,8 @@ def generate_loading_check() -> typing.List[tuple[str, str]]: @pytest.mark.parametrize("hostname,sharename", generate_loading_check()) def test_loading(hostname: str, sharename: str) -> None: + # Get a tmp file of size 4K + tmpfile = testhelper.get_tmp_file(size=4 * 1024) mount_params: dict[str, str] = testhelper.get_mount_parameters( test_info, sharename ) @@ -322,6 +343,7 @@ def test_loading(hostname: str, sharename: str) -> None: ret_queue, mount_params, process_testdir, + tmpfile, ), ) processes.append(process) @@ -359,6 +381,7 @@ def test_loading(hostname: str, sharename: str) -> None: smbclient.rmdir(testdir) smbclient.disconnect() + os.unlink(tmpfile) print_stats("Total:", total_stats) assert (