Add support for multiple values

Watchful1 2021-10-14 19:33:25 -07:00
parent 4501ec236f
commit 50be918a1c
4 changed files with 74 additions and 5 deletions

@@ -130,13 +130,20 @@ def read_lines_zst(file_name):
 # base of each separate process. Loads a file, iterates through lines and writes out
 # the ones where the `field` of the object matches `value`. Also passes status
 # information back to the parent via a queue
-def process_file(file, working_folder, queue, field, value):
+def process_file(file, working_folder, queue, field, value, values):
 	output_file = None
 	try:
 		for line, file_bytes_processed in read_lines_zst(file.input_path):
 			try:
 				obj = json.loads(line)
-				if obj[field] == value:
+				matched = False
+				if value is not None:
+					if obj[field] == value:
+						matched = True
+				elif obj[field] in values:
+					matched = True
+				if matched:
 					if output_file is None:
 						if file.output_path is None:
 							created = datetime.utcfromtimestamp(int(obj['created_utc']))
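
The new matching logic keeps the old single-value equality check as a fast path and only falls back to set membership when a list of values was supplied. A minimal standalone sketch of that rule, outside the diff (the function name `line_matches` and the sample lines are illustrative, not part of the commit):

```python
import json

# `value` and `values` are mutually exclusive: the argument parsing further
# down sets exactly one of them and leaves the other as None.
def line_matches(line, field, value, values):
	obj = json.loads(line)
	if value is not None:
		# single-value mode: plain equality, same behavior as before this commit
		return obj[field] == value
	# multi-value mode: membership test against the parsed set
	return obj[field] in values

# single-value mode
assert line_matches('{"subreddit": "askreddit"}', "subreddit", "askreddit", None)
# multi-value mode
assert line_matches('{"subreddit": "pics"}', "subreddit", None, {"pics", "videos"})
```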
@@ -167,7 +174,7 @@ if __name__ == '__main__':
 	parser.add_argument("output", help="The output folder to store temporary files in and write the output to")
 	parser.add_argument("--name", help="What to name the output file", default="pushshift")
 	parser.add_argument("--field", help="When deciding what lines to keep, use this field for comparisons", default="subreddit")
-	parser.add_argument("--value", help="When deciding what lines to keep, compare the field to this value", default="pushshift")
+	parser.add_argument("--value", help="When deciding what lines to keep, compare the field to this value. Supports a comma separated list", default="pushshift")
 	parser.add_argument("--processes", help="Number of processes to use", default=10, type=int)
 	parser.add_argument(
 		"--error_rate", help=
@@ -182,6 +189,21 @@ if __name__ == '__main__':
 	log.info(f"Loading files from: {args.input}")
 	log.info(f"Writing output to: {(os.path.join(args.output, args.name + '.zst'))}")
+	value_strings = args.value.split(",")
+	value = None
+	values = None
+	if len(value_strings) > 1:
+		values = set()
+		for value_inner in value_strings:
+			values.add(value_inner)
+		log.info(f"Checking field {args.field} for values {(', '.join(value_strings))}")
+	elif len(value_strings) == 1:
+		value = value_strings[0]
+		log.info(f"Checking field {args.field} for value {value}")
+	else:
+		log.info(f"Invalid value specified, aborting: {args.value}")
+		sys.exit()
 	multiprocessing.set_start_method('spawn')
 	queue = multiprocessing.Manager().Queue()
 	input_files = load_file_list(args.output, args.name)
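
Worth noting about the parsing above: `str.split(",")` always returns at least one element, so the `else` branch that logs "Invalid value specified" is effectively unreachable; even an empty `--value` yields `[""]` and lands in the single-value branch. A quick demonstration of the parsing rule as plain Python (the helper name `parse_value_arg` is illustrative):

```python
# str.split(",") never returns an empty list, so len(value_strings) >= 1:
assert "".split(",") == [""]
assert "pushshift".split(",") == ["pushshift"]
assert "askreddit,pics".split(",") == ["askreddit", "pics"]

# the same rule the commit applies, condensed into one helper
def parse_value_arg(raw):
	parts = raw.split(",")
	if len(parts) > 1:
		return None, set(parts)  # multi-value mode: value=None, values=set
	return parts[0], None        # single-value mode: values=None

assert parse_value_arg("pics,videos") == (None, {"pics", "videos"})
assert parse_value_arg("pushshift") == ("pushshift", None)
```

With this change the script can be invoked with something like `--field subreddit --value askreddit,pics` to keep lines whose field matches any of the listed values.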
@@ -226,7 +248,7 @@ if __name__ == '__main__':
 		log.debug(f"Processing file: {file.input_path}")
 	# start the workers
 	with multiprocessing.Pool(processes=min(args.processes, len(files_to_process))) as pool:
-		workers = pool.starmap_async(process_file, [(file, working_folder, queue, args.field, args.value) for file in files_to_process], error_callback=log.info)
+		workers = pool.starmap_async(process_file, [(file, working_folder, queue, args.field, value, values) for file in files_to_process], error_callback=log.info)
 		while not workers.ready():
 			# loop until the workers are all done, pulling in status messages as they are sent
 			file_update = queue.get()
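
The last hunk threads the parsed `value`/`values` pair through to the workers in place of the raw `args.value` string, so every process receives both and applies the same rule locally. A reduced sketch of that fan-out using only the standard library (`fake_process_file` stands in for the real `process_file`, and the file names are made up):

```python
import multiprocessing

# Stand-in worker: just records the argument tuple it was handed.
def fake_process_file(file, working_folder, queue, field, value, values):
	queue.put((file, field, value, values))

if __name__ == "__main__":
	multiprocessing.set_start_method("spawn")
	queue = multiprocessing.Manager().Queue()
	files = ["a.zst", "b.zst"]
	with multiprocessing.Pool(processes=2) as pool:
		# one argument tuple per file; value/values ride along unchanged
		workers = pool.starmap_async(
			fake_process_file,
			[(f, "tmp", queue, "subreddit", None, {"pics", "videos"}) for f in files])
		workers.wait()
	while not queue.empty():
		print(queue.get())
```

Passing the pre-parsed set to each worker, rather than re-splitting the string in every process, means the set is built once and each per-line check is an O(1) membership test.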