Review comments

This commit is contained in:
Kegan Dougal 2016-11-22 13:42:11 +00:00
parent 6d4e6d4cba
commit c3d963ac24
2 changed files with 11 additions and 21 deletions

View File

@ -145,9 +145,7 @@ def _copy_field(src, dst, field):
# the empty objects if the key didn't exist. # the empty objects if the key didn't exist.
sub_out_dict = dst sub_out_dict = dst
for sub_field in field: for sub_field in field:
if sub_field not in sub_out_dict: sub_out_dict = sub_out_dict.setdefault(sub_field, {})
sub_out_dict[sub_field] = {}
sub_out_dict = sub_out_dict[sub_field]
sub_out_dict[key_to_move] = sub_dict[key_to_move] sub_out_dict[key_to_move] = sub_dict[key_to_move]
@ -176,12 +174,10 @@ def only_fields(dictionary, fields):
split_fields = [SPLIT_FIELD_REGEX.split(f) for f in fields] split_fields = [SPLIT_FIELD_REGEX.split(f) for f in fields]
# for each element of the output array of arrays: # for each element of the output array of arrays:
# remove escaping so we can use the right key names. This purposefully avoids # remove escaping so we can use the right key names.
# using list comprehensions to avoid needless allocations as this may be called split_fields[:] = [
# on a lot of events. [f.replace(r'\.', r'.') for f in field_array] for field_array in split_fields
for field_array in split_fields: ]
for i, field in enumerate(field_array):
field_array[i] = field.replace(r'\.', r'.')
output = {} output = {}
for field_array in split_fields: for field_array in split_fields:
@ -258,8 +254,10 @@ def serialize_event(e, time_now_ms, as_client_event=True,
if as_client_event: if as_client_event:
d = event_format(d) d = event_format(d)
if (only_event_fields and isinstance(only_event_fields, list) and if only_event_fields:
all(isinstance(f, basestring) for f in only_event_fields)): if (not isinstance(only_event_fields, list) or
not all(isinstance(f, basestring) for f in only_event_fields)):
raise TypeError("only_event_fields must be a list of strings")
d = only_fields(d, only_event_fields) d = only_fields(d, only_event_fields)
return d return d

View File

@ -272,7 +272,7 @@ class SerializeEventTestCase(unittest.TestCase):
) )
def test_event_fields_fail_if_fields_not_str(self): def test_event_fields_fail_if_fields_not_str(self):
self.assertEquals( with self.assertRaises(TypeError):
self.serialize( self.serialize(
MockEvent( MockEvent(
room_id="!foo:bar", room_id="!foo:bar",
@ -281,12 +281,4 @@ class SerializeEventTestCase(unittest.TestCase):
}, },
), ),
["room_id", 4] ["room_id", 4]
),
{
"room_id": "!foo:bar",
"content": {
"foo": "bar",
},
"unsigned": {}
}
) )