Change to FileHandle

Watchful1 2023-01-28 11:39:27 -08:00
parent 87d2b22a73
commit 33b5b938c1

@@ -61,6 +61,84 @@ class FileConfig:
return f"{self.input_path} : {self.output_path} : {self.file_size} : {self.complete} : {self.bytes_processed} : {self.lines_processed}"
# another convenience object to read and write from both zst files and ndjson files
class FileHandle:
def __init__(self, path):
self.path = path
if self.path.endswith(".zst"):
self.is_compressed = True
elif self.path.endswith(".ndjson"):
self.is_compressed = False
else:
raise TypeError(f"File type not supported for writing {self.path}")
self.write_handle = None
self.other_handle = None
self.newline_encoded = "\n".encode('utf-8')
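# note: the write handle is opened lazily in write_line rather than here, so an
# output file is only created once the first line is actually written. The
# newline is encoded once up front to avoid re-encoding it for every line.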
# recursively decompress and decode a chunk of bytes. If there's a decode error then read another chunk and try with that, up to a limit of max_window_size bytes
@staticmethod
def read_and_decode(reader, chunk_size, max_window_size, previous_chunk=None, bytes_read=0):
chunk = reader.read(chunk_size)
bytes_read += chunk_size
if previous_chunk is not None:
chunk = previous_chunk + chunk
try:
return chunk.decode()
except UnicodeDecodeError:
if bytes_read > max_window_size:
raise UnicodeError(f"Unable to decode frame after reading {bytes_read:,} bytes")
return FileHandle.read_and_decode(reader, chunk_size, max_window_size, chunk, bytes_read)
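# why the retry works: a multi-byte utf-8 character can be split across a chunk
# boundary, which makes chunk.decode() fail until the next chunk supplies the
# remaining bytes. The max_window_size cap stops a genuinely corrupt frame from
# being re-read forever.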
# open a zst compressed ndjson file, or a regular uncompressed ndjson file and yield lines one at a time
# also passes back file progress
def yield_lines(self):
if self.is_compressed:
with open(self.path, 'rb') as file_handle:
buffer = ''
reader = zstandard.ZstdDecompressor(max_window_size=2**31).stream_reader(file_handle)
while True:
chunk = FileHandle.read_and_decode(reader, 2**27, (2**29) * 2)
if not chunk:
break
lines = (buffer + chunk).split("\n")
for line in lines[:-1]:
yield line, file_handle.tell()
buffer = lines[-1]
reader.close()
else:
with open(self.path, 'r') as file_handle:
line = file_handle.readline()
while line:
yield line.rstrip('\n'), file_handle.tell()
line = file_handle.readline()
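# illustrative usage, not part of this commit (the path below is hypothetical):
#   handle = FileHandle("RS_2023-01.zst")
#   for line, bytes_processed in handle.yield_lines():
#       obj = json.loads(line)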
# write a line, opening the appropriate handle
def write_line(self, line):
if self.write_handle is None:
if self.is_compressed:
self.other_handle = open(self.path, 'wb')
self.write_handle = zstandard.ZstdCompressor().stream_writer(self.other_handle)
else:
self.write_handle = open(self.path, 'w', encoding="utf-8")
if self.is_compressed:
self.write_handle.write(line.encode('utf-8'))
self.write_handle.write(self.newline_encoded)
else:
self.write_handle.write(line)
self.write_handle.write("\n")
def close(self):
if self.write_handle:
self.write_handle.close()
if self.other_handle:
self.other_handle.close()
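# illustrative round trip, not part of this commit (the path is hypothetical):
# write_line picks the zst or plain text writer from the file extension, so an
# object written here can be read back later with yield_lines.
#   output = FileHandle("filtered.ndjson")
#   output.write_line(json.dumps(obj))
#   output.close()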
# used for calculating running average of read speed
class Queue:
def __init__(self, max_size):
@@ -108,51 +186,19 @@ def load_file_list(status_json):
return None, None, None
# recursively decompress and decode a chunk of bytes. If there's a decode error then read another chunk and try with that, up to a limit of max_window_size bytes
def read_and_decode(reader, chunk_size, max_window_size, previous_chunk=None, bytes_read=0):
chunk = reader.read(chunk_size)
bytes_read += chunk_size
if previous_chunk is not None:
chunk = previous_chunk + chunk
try:
return chunk.decode()
except UnicodeDecodeError:
if bytes_read > max_window_size:
raise UnicodeError(f"Unable to decode frame after reading {bytes_read:,} bytes")
return read_and_decode(reader, chunk_size, max_window_size, chunk, bytes_read)
# open a zst compressed ndjson file and yield lines one at a time
# also passes back file progress
def read_lines_zst(file_name):
with open(file_name, 'rb') as file_handle:
buffer = ''
reader = zstandard.ZstdDecompressor(max_window_size=2**31).stream_reader(file_handle)
while True:
chunk = read_and_decode(reader, 2**27, (2**29) * 2)
if not chunk:
break
lines = (buffer + chunk).split("\n")
for line in lines[:-1]:
yield line, file_handle.tell()
buffer = lines[-1]
reader.close()
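# the module level read_and_decode and read_lines_zst above duplicate the logic
# now held by FileHandle; this hunk shrinks as the read side switches over to
# FileHandle.yield_lines.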
# base of each separate process. Loads a file, iterates through lines and writes out
# the ones where the `field` of the object matches `value`. Also passes status
# information back to the parent via a queue
def process_file(file, queue, field, value, values, case_sensitive):
output_file = None
def process_file(file, queue, field, value, values):
queue.put(file)
input_handle = FileHandle(file.input_path)
output_handle = FileHandle(file.output_path)
try:
for line, file_bytes_processed in read_lines_zst(file.input_path):
for line, file_bytes_processed in input_handle.yield_lines():
try:
obj = json.loads(line)
matched = False
observed = obj[field] if case_sensitive else obj[field].lower()
observed = obj[field].lower()
if value is not None:
if observed == value:
matched = True
@@ -160,10 +206,7 @@ def process_file(file, queue, field, value, values, case_sensitive):
matched = True
if matched:
if output_file is None:
output_file = open(file.output_path, 'w', encoding="utf-8")
output_file.write(line)
output_file.write("\n")
output_handle.write_line(line)
except (KeyError, json.JSONDecodeError) as err:
file.error_lines += 1
file.lines_processed += 1
@@ -171,9 +214,7 @@ def process_file(file, queue, field, value, values, case_sensitive):
file.bytes_processed = file_bytes_processed
queue.put(file)
if output_file is not None:
output_file.close()
output_handle.close()
file.complete = True
file.bytes_processed = file.file_size
except Exception as err:
@@ -191,8 +232,8 @@ if __name__ == '__main__':
parser.add_argument("--value", help="When deciding what lines to keep, compare the field to this value. Supports a comma separated list. This is case sensitive", default="pushshift")
parser.add_argument("--value_list", help="A file of newline separated values to use. Overrides the value param if it is set", default=None)
parser.add_argument("--processes", help="Number of processes to use", default=10, type=int)
parser.add_argument("--case-sensitive", help="Matching should be case sensitive", action="store_true")
parser.add_argument("--file_filter", help="Regex filenames have to match to be processed", default="^rc_|rs_")
parser.add_argument("--compress_intermediate", help="Compress the intermediate files, use if the filter will result in a very large amount of data", action="store_true")
parser.add_argument(
"--error_rate", help=
"Percentage as an integer from 0 to 100 of the lines where the field can be missing. For the subreddit field especially, "
@@ -201,7 +242,7 @@ if __name__ == '__main__':
script_type = "split"
args = parser.parse_args()
arg_string = f"{args.field}:{args.value}:{args.case_sensitive}"
arg_string = f"{args.field}:{(args.value if args.value else args.value_list)}"
if args.debug:
log.setLevel(logging.DEBUG)
@@ -212,9 +253,6 @@ if __name__ == '__main__':
else:
log.info(f"Writing output to working folder")
if not args.case_sensitive:
args.value = args.value.lower()
value = None
values = None
if args.value_list:
@@ -222,7 +260,7 @@ if __name__ == '__main__':
with open(args.value_list, 'r') as value_list_handle:
values = set()
for line in value_list_handle:
values.add(line.strip())
values.add(line.strip().lower())
log.info(f"Comparing {args.field} against {len(values)} values")
else:
@@ -230,10 +268,10 @@ if __name__ == '__main__':
if len(value_strings) > 1:
values = set()
for value_inner in value_strings:
values.add(value_inner)
values.add(value_inner.lower())
log.info(f"Checking field {args.field} for values {(', '.join(value_strings))}")
elif len(value_strings) == 1:
value = value_strings[0]
value = value_strings[0].lower()
log.info(f"Checking field {args.field} for value {value}")
else:
log.info(f"Invalid value specified, aborting: {args.value}")
@@ -259,7 +297,7 @@ if __name__ == '__main__':
for file_name in files:
if file_name.endswith(".zst") and re.search(args.file_filter, file_name, re.IGNORECASE) is not None:
input_path = os.path.join(subdir, file_name)
output_path = os.path.join(args.working, file_name[:-4])
output_path = os.path.join(args.working, f"{file_name[:-4]}.{('zst' if args.compress_intermediate else 'ndjson')}")
input_files.append(FileConfig(input_path, output_path=output_path))
save_file_list(input_files, args.working, status_json, arg_string, script_type)
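# note: the intermediate output path now keeps a .zst or .ndjson extension
# depending on --compress_intermediate, which is what FileHandle later uses to
# decide whether to wrap the handle in a zstandard stream writer.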
@@ -295,7 +333,7 @@ if __name__ == '__main__':
log.info(f"Processing file: {file.input_path}")
# start the workers
with multiprocessing.Pool(processes=min(args.processes, len(files_to_process))) as pool:
workers = pool.starmap_async(process_file, [(file, queue, args.field, value, values, args.case_sensitive) for file in files_to_process], chunksize=1, error_callback=log.info)
workers = pool.starmap_async(process_file, [(file, queue, args.field, value, values) for file in files_to_process], chunksize=1, error_callback=log.info)
while not workers.ready():
# loop until the workers are all done, pulling in status messages as they are sent
file_update = queue.get()
@@ -383,7 +421,7 @@ if __name__ == '__main__':
split = False
for working_file_path in working_file_paths:
files_combined += 1
log.info(f"Reading {files_combined}/{len(working_file_paths)} : {os.path.split(working_file_path)[1]}")
log.info(f"From {files_combined}/{len(working_file_paths)} files to {len(all_handles):,} output handles : {output_lines:,} lines : {os.path.split(working_file_path)[1]}")
working_file_name = os.path.split(working_file_path)[1]
if working_file_name.startswith("RS"):
file_type = "submissions"
@@ -392,37 +430,38 @@ if __name__ == '__main__':
else:
log.warning(f"Unknown working file type, skipping: {working_file_name}")
continue
input_handle = FileHandle(working_file_path)
if file_type not in output_handles:
output_handles[file_type] = {}
file_type_handles = output_handles[file_type]
with open(working_file_path, 'r') as input_file:
for line in input_file:
output_lines += 1
if split:
obj = json.loads(line)
observed_case = obj[args.field]
else:
observed_case = value
observed = observed_case if args.case_sensitive else observed_case.lower()
if observed not in file_type_handles:
if args.output:
if not os.path.exists(args.output):
os.makedirs(args.output)
output_file_path = os.path.join(args.output, f"{observed_case}_{file_type}.zst")
else:
output_file_path = f"{observed_case}_{file_type}.zst"
log.info(f"Writing to file {output_file_path}")
file_handle = open(output_file_path, 'wb')
writer = zstandard.ZstdCompressor().stream_writer(file_handle)
file_type_handles[observed] = writer
all_handles.append(writer)
all_handles.append(file_handle)
else:
writer = file_type_handles[observed]
encoded_line = line.encode('utf-8')
writer.write(encoded_line)
for line, file_bytes_processed in input_handle.yield_lines():
output_lines += 1
if split:
obj = json.loads(line)
observed_case = obj[args.field]
else:
observed_case = value
observed = observed_case.lower()
if observed not in file_type_handles:
if args.output:
if not os.path.exists(args.output):
os.makedirs(args.output)
output_file_path = os.path.join(args.output, f"{observed_case}_{file_type}.zst")
else:
output_file_path = f"{observed_case}_{file_type}.zst"
log.debug(f"Writing to file {output_file_path}")
output_handle = FileHandle(output_file_path)
file_type_handles[observed] = output_handle
all_handles.append(output_handle)
else:
output_handle = file_type_handles[observed]
output_handle.write_line(line)
if output_lines % 100000 == 0:
log.info(f"From {files_combined}/{len(working_file_paths)} files to {len(all_handles):,} output handles : {output_lines:,} lines : {os.path.split(working_file_path)[1]}")
log.info(f"From {files_combined}/{len(working_file_paths)} files to {len(all_handles):,} output handles : {output_lines:,} lines")
for handle in all_handles:
handle.close()
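# note: in the combine stage each distinct lowercased value gets its own
# FileHandle per file type; every handle is also appended to all_handles so
# this final loop can close them all in one place.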