Merge remote-tracking branch 'origin/master'

This commit is contained in:
Watchful1 2022-07-15 23:39:45 -07:00
commit 3fa63048e3

View file

@ -135,17 +135,18 @@ def read_lines_zst(file_name):
# base of each separate process. Loads a file, iterates through lines and writes out
# the ones where the `field` of the object matches `value`. Also passes status
# information back to the parent via a queue
def process_file(file, working_folder, queue, field, value, values):
def process_file(file, working_folder, queue, field, value, values, case_sensitive):
output_file = None
try:
for line, file_bytes_processed in read_lines_zst(file.input_path):
try:
obj = json.loads(line)
matched = False
observed = obj[field] if case_sensitive else obj[field].lower()
if value is not None:
if obj[field] == value:
if observed == value:
matched = True
elif obj[field] in values:
elif observed in values:
matched = True
if matched:
@ -181,11 +182,13 @@ if __name__ == '__main__':
parser.add_argument("--field", help="When deciding what lines to keep, use this field for comparisons", default="subreddit")
parser.add_argument("--value", help="When deciding what lines to keep, compare the field to this value. Supports a comma separated list. This is case sensitive", default="pushshift")
parser.add_argument("--processes", help="Number of processes to use", default=10, type=int)
parser.add_argument("--case-sensitive", help="Matching should be case sensitive", action="store_true")
parser.add_argument(
"--error_rate", help=
"Percentage as an integer from 0 to 100 of the lines where the field can be missing. For the subreddit field especially, "
"there are a number of posts that simply don't have a subreddit attached", default=1, type=int)
parser.add_argument("--debug", help="Enable debug logging", action='store_const', const=True, default=False)
args = parser.parse_args()
if args.debug:
@ -194,6 +197,9 @@ if __name__ == '__main__':
log.info(f"Loading files from: {args.input}")
log.info(f"Writing output to: {(os.path.join(args.output, args.name + '.zst'))}")
if not args.case_sensitive:
args.value = args.value.lower()
value_strings = args.value.split(",")
value = None
values = None
@ -254,7 +260,7 @@ if __name__ == '__main__':
log.debug(f"Processing file: {file.input_path}")
# start the workers
with multiprocessing.Pool(processes=min(args.processes, len(files_to_process))) as pool:
workers = pool.starmap_async(process_file, [(file, working_folder, queue, args.field, value, values) for file in files_to_process], error_callback=log.info)
workers = pool.starmap_async(process_file, [(file, working_folder, queue, args.field, value, values, args.case_sensitive) for file in files_to_process], error_callback=log.info)
while not workers.ready():
# loop until the workers are all done, pulling in status messages as they are sent
file_update = queue.get()