From 33b5b938c1cfad446bf7324cb0e0c7625c2c89f1 Mon Sep 17 00:00:00 2001 From: Watchful1 Date: Sat, 28 Jan 2023 11:39:27 -0800 Subject: [PATCH] Change to FileHandle --- scripts/combine_folder_multiprocess.py | 201 +++++++++++++++---------- 1 file changed, 120 insertions(+), 81 deletions(-) diff --git a/scripts/combine_folder_multiprocess.py b/scripts/combine_folder_multiprocess.py index affd154..c710d5e 100644 --- a/scripts/combine_folder_multiprocess.py +++ b/scripts/combine_folder_multiprocess.py @@ -61,6 +61,84 @@ class FileConfig: return f"{self.input_path} : {self.output_path} : {self.file_size} : {self.complete} : {self.bytes_processed} : {self.lines_processed}" +# another convenience object to read and write from both zst files and ndjson files +class FileHandle: + def __init__(self, path): + self.path = path + if self.path.endswith(".zst"): + self.is_compressed = True + elif self.path.endswith(".ndjson"): + self.is_compressed = False + else: + raise TypeError(f"File type not supported for writing {self.path}") + + self.write_handle = None + self.other_handle = None + self.newline_encoded = "\n".encode('utf-8') + + # recursively decompress and decode a chunk of bytes. If there's a decode error then read another chunk and try with that, up to a limit of max_window_size bytes + @staticmethod + def read_and_decode(reader, chunk_size, max_window_size, previous_chunk=None, bytes_read=0): + chunk = reader.read(chunk_size) + bytes_read += chunk_size + if previous_chunk is not None: + chunk = previous_chunk + chunk + try: + return chunk.decode() + except UnicodeDecodeError: + if bytes_read > max_window_size: + raise UnicodeError(f"Unable to decode frame after reading {bytes_read:,} bytes") + return FileHandle.read_and_decode(reader, chunk_size, max_window_size, chunk, bytes_read) + + # open a zst compressed ndjson file, or a regular uncompressed ndjson file and yield lines one at a time + # also passes back file progress + def yield_lines(self): + if self.is_compressed: + with open(self.path, 'rb') as file_handle: + buffer = '' + reader = zstandard.ZstdDecompressor(max_window_size=2**31).stream_reader(file_handle) + while True: + chunk = FileHandle.read_and_decode(reader, 2**27, (2**29) * 2) + if not chunk: + break + lines = (buffer + chunk).split("\n") + + for line in lines[:-1]: + yield line, file_handle.tell() + + buffer = lines[-1] + reader.close() + + else: + with open(self.path, 'r') as file_handle: + line = file_handle.readline() + while line: + yield line.rstrip('\n'), file_handle.tell() + line = file_handle.readline() + + # write a line, opening the appropriate handle + def write_line(self, line): + if self.write_handle is None: + if self.is_compressed: + self.other_handle = open(self.path, 'wb') + self.write_handle = zstandard.ZstdCompressor().stream_writer(self.other_handle) + else: + self.write_handle = open(self.path, 'w', encoding="utf-8") + + if self.is_compressed: + self.write_handle.write(line.encode('utf-8')) + self.write_handle.write(self.newline_encoded) + else: + self.write_handle.write(line) + self.write_handle.write("\n") + + def close(self): + if self.write_handle: + self.write_handle.close() + if self.other_handle: + self.other_handle.close() + + # used for calculating running average of read speed class Queue: def __init__(self, max_size): @@ -108,51 +186,19 @@ def load_file_list(status_json): return None, None, None -# recursively decompress and decode a chunk of bytes. 
If there's a decode error then read another chunk and try with that, up to a limit of max_window_size bytes -def read_and_decode(reader, chunk_size, max_window_size, previous_chunk=None, bytes_read=0): - chunk = reader.read(chunk_size) - bytes_read += chunk_size - if previous_chunk is not None: - chunk = previous_chunk + chunk - try: - return chunk.decode() - except UnicodeDecodeError: - if bytes_read > max_window_size: - raise UnicodeError(f"Unable to decode frame after reading {bytes_read:,} bytes") - return read_and_decode(reader, chunk_size, max_window_size, chunk, bytes_read) - - -# open a zst compressed ndjson file and yield lines one at a time -# also passes back file progress -def read_lines_zst(file_name): - with open(file_name, 'rb') as file_handle: - buffer = '' - reader = zstandard.ZstdDecompressor(max_window_size=2**31).stream_reader(file_handle) - while True: - chunk = read_and_decode(reader, 2**27, (2**29) * 2) - if not chunk: - break - lines = (buffer + chunk).split("\n") - - for line in lines[:-1]: - yield line, file_handle.tell() - - buffer = lines[-1] - reader.close() - - # base of each separate process. Loads a file, iterates through lines and writes out # the ones where the `field` of the object matches `value`. Also passes status # information back to the parent via a queue -def process_file(file, queue, field, value, values, case_sensitive): - output_file = None +def process_file(file, queue, field, value, values): queue.put(file) + input_handle = FileHandle(file.input_path) + output_handle = FileHandle(file.output_path) try: - for line, file_bytes_processed in read_lines_zst(file.input_path): + for line, file_bytes_processed in input_handle.yield_lines(): try: obj = json.loads(line) matched = False - observed = obj[field] if case_sensitive else obj[field].lower() + observed = obj[field].lower() if value is not None: if observed == value: matched = True @@ -160,10 +206,7 @@ def process_file(file, queue, field, value, values, case_sensitive): matched = True if matched: - if output_file is None: - output_file = open(file.output_path, 'w', encoding="utf-8") - output_file.write(line) - output_file.write("\n") + output_handle.write_line(line) except (KeyError, json.JSONDecodeError) as err: file.error_lines += 1 file.lines_processed += 1 @@ -171,9 +214,7 @@ def process_file(file, queue, field, value, values, case_sensitive): file.bytes_processed = file_bytes_processed queue.put(file) - if output_file is not None: - output_file.close() - + output_handle.close() file.complete = True file.bytes_processed = file.file_size except Exception as err: @@ -191,8 +232,8 @@ if __name__ == '__main__': parser.add_argument("--value", help="When deciding what lines to keep, compare the field to this value. Supports a comma separated list. This is case sensitive", default="pushshift") parser.add_argument("--value_list", help="A file of newline separated values to use. 
Overrides the value param if it is set", default=None) parser.add_argument("--processes", help="Number of processes to use", default=10, type=int) - parser.add_argument("--case-sensitive", help="Matching should be case sensitive", action="store_true") parser.add_argument("--file_filter", help="Regex filenames have to match to be processed", default="^rc_|rs_") + parser.add_argument("--compress_intermediate", help="Compress the intermediate files, use if the filter will result in a very large amount of data", action="store_true") parser.add_argument( "--error_rate", help= "Percentage as an integer from 0 to 100 of the lines where the field can be missing. For the subreddit field especially, " @@ -201,7 +242,7 @@ if __name__ == '__main__': script_type = "split" args = parser.parse_args() - arg_string = f"{args.field}:{args.value}:{args.case_sensitive}" + arg_string = f"{args.field}:{(args.value if args.value else args.value_list)}" if args.debug: log.setLevel(logging.DEBUG) @@ -212,9 +253,6 @@ if __name__ == '__main__': else: log.info(f"Writing output to working folder") - if not args.case_sensitive: - args.value = args.value.lower() - value = None values = None if args.value_list: @@ -222,7 +260,7 @@ if __name__ == '__main__': with open(args.value_list, 'r') as value_list_handle: values = set() for line in value_list_handle: - values.add(line.strip()) + values.add(line.strip().lower()) log.info(f"Comparing {args.field} against {len(values)} values") else: @@ -230,10 +268,10 @@ if __name__ == '__main__': if len(value_strings) > 1: values = set() for value_inner in value_strings: - values.add(value_inner) + values.add(value_inner.lower()) log.info(f"Checking field {args.field} for values {(', '.join(value_strings))}") elif len(value_strings) == 1: - value = value_strings[0] + value = value_strings[0].lower() log.info(f"Checking field {args.field} for value {value}") else: log.info(f"Invalid value specified, aborting: {args.value}") @@ -259,7 +297,7 @@ if __name__ == '__main__': for file_name in files: if file_name.endswith(".zst") and re.search(args.file_filter, file_name, re.IGNORECASE) is not None: input_path = os.path.join(subdir, file_name) - output_path = os.path.join(args.working, file_name[:-4]) + output_path = os.path.join(args.working, f"{file_name[:-4]}.{('zst' if args.compress_intermediate else 'ndjson')}") input_files.append(FileConfig(input_path, output_path=output_path)) save_file_list(input_files, args.working, status_json, arg_string, script_type) @@ -295,7 +333,7 @@ if __name__ == '__main__': log.info(f"Processing file: {file.input_path}") # start the workers with multiprocessing.Pool(processes=min(args.processes, len(files_to_process))) as pool: - workers = pool.starmap_async(process_file, [(file, queue, args.field, value, values, args.case_sensitive) for file in files_to_process], chunksize=1, error_callback=log.info) + workers = pool.starmap_async(process_file, [(file, queue, args.field, value, values) for file in files_to_process], chunksize=1, error_callback=log.info) while not workers.ready(): # loop until the workers are all done, pulling in status messages as they are sent file_update = queue.get() @@ -383,7 +421,7 @@ if __name__ == '__main__': split = False for working_file_path in working_file_paths: files_combined += 1 - log.info(f"Reading {files_combined}/{len(working_file_paths)} : {os.path.split(working_file_path)[1]}") + log.info(f"From {files_combined}/{len(working_file_paths)} files to {len(all_handles):,} output handles : {output_lines:,} lines : 
{os.path.split(working_file_path)[1]}") working_file_name = os.path.split(working_file_path)[1] if working_file_name.startswith("RS"): file_type = "submissions" @@ -392,37 +430,38 @@ if __name__ == '__main__': else: log.warning(f"Unknown working file type, skipping: {working_file_name}") continue + input_handle = FileHandle(working_file_path) if file_type not in output_handles: output_handles[file_type] = {} file_type_handles = output_handles[file_type] - with open(working_file_path, 'r') as input_file: - for line in input_file: - output_lines += 1 - if split: - obj = json.loads(line) - observed_case = obj[args.field] - else: - observed_case = value - observed = observed_case if args.case_sensitive else observed_case.lower() - if observed not in file_type_handles: - if args.output: - if not os.path.exists(args.output): - os.makedirs(args.output) - output_file_path = os.path.join(args.output, f"{observed_case}_{file_type}.zst") - else: - output_file_path = f"{observed_case}_{file_type}.zst" - log.info(f"Writing to file {output_file_path}") - file_handle = open(output_file_path, 'wb') - writer = zstandard.ZstdCompressor().stream_writer(file_handle) - file_type_handles[observed] = writer - all_handles.append(writer) - all_handles.append(file_handle) - else: - writer = file_type_handles[observed] - encoded_line = line.encode('utf-8') - writer.write(encoded_line) + for line, file_bytes_processed in input_handle.yield_lines(): + output_lines += 1 + if split: + obj = json.loads(line) + observed_case = obj[args.field] + else: + observed_case = value + observed = observed_case.lower() + if observed not in file_type_handles: + if args.output: + if not os.path.exists(args.output): + os.makedirs(args.output) + output_file_path = os.path.join(args.output, f"{observed_case}_{file_type}.zst") + else: + output_file_path = f"{observed_case}_{file_type}.zst" + log.debug(f"Writing to file {output_file_path}") + output_handle = FileHandle(output_file_path) + file_type_handles[observed] = output_handle + all_handles.append(output_handle) + else: + output_handle = file_type_handles[observed] + output_handle.write_line(line) + if output_lines % 100000 == 0: + log.info(f"From {files_combined}/{len(working_file_paths)} files to {len(all_handles):,} output handles : {output_lines:,} lines : {os.path.split(working_file_path)[1]}") + + log.info(f"From {files_combined}/{len(working_file_paths)} files to {len(all_handles):,} output handles : {output_lines:,} lines") for handle in all_handles: handle.close()
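
For reference, a minimal sketch of using the new FileHandle wrapper on its own, assuming the script can be imported as a module (run from the scripts/ directory with zstandard installed). The file names and the subreddit filter are hypothetical, not something the patch itself adds:

import json

from combine_folder_multiprocess import FileHandle

# the .zst extension selects the zstandard stream reader, .ndjson selects plain text
input_handle = FileHandle("comments.zst")
output_handle = FileHandle("filtered.ndjson")

for line, bytes_processed in input_handle.yield_lines():
    try:
        obj = json.loads(line)
    except json.JSONDecodeError:
        continue  # the real script counts these as error_lines instead of skipping silently
    # lowercase comparison, matching the patch's now always case-insensitive matching
    if obj.get("subreddit", "").lower() == "askreddit":
        output_handle.write_line(line)

# close() only tracks write handles; reads are closed inside yield_lines itself
output_handle.close()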
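
The recursive FileHandle.read_and_decode exists because a fixed-size read can stop in the middle of a multi-byte UTF-8 character; on a UnicodeDecodeError it pulls another chunk and retries, giving up after max_window_size bytes. A small self-contained illustration of that behaviour, again assuming the module is importable; the file name and the tiny chunk size are chosen purely for the demonstration:

import zstandard

from combine_folder_multiprocess import FileHandle

data = ('{"body": "emoji \U0001F600"}\n' * 100).encode("utf-8")
with open("tiny_example.zst", "wb") as handle:
    handle.write(zstandard.ZstdCompressor().compress(data))

with open("tiny_example.zst", "rb") as handle:
    reader = zstandard.ZstdDecompressor(max_window_size=2**31).stream_reader(handle)
    decoded = ""
    while True:
        # chunk_size=7 is deliberately tiny so some reads end inside the 4 byte emoji,
        # forcing the UnicodeDecodeError retry path
        chunk = FileHandle.read_and_decode(reader, 7, 2**20)
        if not chunk:
            break
        decoded += chunk
    reader.close()

assert decoded == data.decode("utf-8")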
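
The worker wiring (pool.starmap_async plus a queue that carries FileConfig status updates back to the parent) is unchanged by this patch apart from the smaller argument tuple. A stripped-down, self-contained sketch of that pattern, using a manager queue and illustrative names in place of process_file and FileConfig; the real script blocks on queue.get(), while this version polls so it can exit cleanly:

import multiprocessing
import time


def worker(task_id, queue):
    # stand-in for process_file: report progress back through the shared queue
    queue.put((task_id, "started"))
    time.sleep(0.1)
    queue.put((task_id, "done"))


if __name__ == "__main__":
    with multiprocessing.Manager() as manager:
        queue = manager.Queue(40)
        tasks = list(range(4))
        with multiprocessing.Pool(processes=2) as pool:
            workers = pool.starmap_async(worker, [(task, queue) for task in tasks], chunksize=1)
            while not workers.ready():
                # drain status messages as they arrive, like the script's progress loop
                while not queue.empty():
                    print("update:", queue.get())
                time.sleep(0.05)
        # pick up anything still queued after the pool finishes
        while not queue.empty():
            print("update:", queue.get())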