requested changes

This commit is contained in:
H. Shay 2024-09-11 15:05:11 -07:00
parent 1893b3c722
commit 61a0fd5e38
3 changed files with 19 additions and 29 deletions

View File

@ -47,17 +47,16 @@ export class NsfwProtection extends Protection {
if (event['type'] === 'm.room.message') { if (event['type'] === 'm.room.message') {
const content = event['content'] || {}; const content = event['content'] || {};
const msgtype = content['msgtype'] || 'm.text'; const msgtype = content['msgtype'] || 'm.text';
const formattedBody = content['formatted_body'] || ''; const isMedia = msgtype === 'm.image';
const isMedia = msgtype === 'm.image' || formattedBody.toLowerCase().includes('<img');
if (isMedia) { if (isMedia) {
const mxc = content["url"] const mxc = content["url"];
const image = await mjolnir.client.downloadContent(mxc) const image = await mjolnir.client.downloadContent(mxc);
const decodedImage = await node.decodeImage(image.data, 3); const decodedImage = await node.decodeImage(image.data, 3);
const predictions = await this.model.classify(decodedImage) const predictions = await this.model.classify(decodedImage);
for (const prediction of predictions) { for (const prediction of predictions) {
if (prediction["className"] === "Porn") { if (["Hentai", "Porn"].includes(prediction["className"])) {
if (prediction["probability"] > mjolnir.config.nsfwSensitivity) { if (prediction["probability"] > mjolnir.config.nsfwSensitivity) {
await mjolnir.managementRoomOutput.logMessage(LogLevel.INFO, "NSFWProtection", `Redacting ${event["event_id"]} for inappropriate content.`); await mjolnir.managementRoomOutput.logMessage(LogLevel.INFO, "NSFWProtection", `Redacting ${event["event_id"]} for inappropriate content.`);
try { try {
@ -67,18 +66,9 @@ export class NsfwProtection extends Protection {
} }
} }
} else if (prediction["className"] === "Hentai") {
if (prediction["probability"] > mjolnir.config.nsfwSensitivity) {
await mjolnir.managementRoomOutput.logMessage(LogLevel.INFO, "NSFWProtection", `Redacting ${event["event_id"]} for inappropriate content.`);
try {
mjolnir.client.redactEvent(roomId, event["event_id"])
} catch (err) {
await mjolnir.managementRoomOutput.logMessage(LogLevel.ERROR, "NSFWProtection", `There was an error redacting ${event["event_id"]}: ${err}`);
}
}
} }
} }
decodedImage.dispose() decodedImage.dispose();
} }
} }
} }

View File

@ -103,7 +103,7 @@ export class ProtectionManager {
} }
if (protection.enabled) { if (protection.enabled) {
if (protection.name === "NsfwProtection") { if (protection.name === "NsfwProtection") {
(protection as NsfwProtection).initialize() (protection as NsfwProtection).initialize();
} }
for (let roomId of this.mjolnir.protectedRoomsTracker.getProtectedRooms()) { for (let roomId of this.mjolnir.protectedRoomsTracker.getProtectedRooms()) {
await protection.startProtectingRoom(this.mjolnir, roomId); await protection.startProtectingRoom(this.mjolnir, roomId);

View File

@ -15,7 +15,7 @@ describe("Test: NSFW protection", function () {
room = await client.createRoom({ invite: [mjolnirId] }); room = await client.createRoom({ invite: [mjolnirId] });
await client.joinRoom(room); await client.joinRoom(room);
await client.joinRoom(this.config.managementRoom); await client.joinRoom(this.config.managementRoom);
await client.setUserPowerLevel(mjolnirId, room, 100) await client.setUserPowerLevel(mjolnirId, room, 100);
}) })
this.afterEach(async function () { this.afterEach(async function () {
await client.stop(); await client.stop();
@ -34,12 +34,12 @@ describe("Test: NSFW protection", function () {
return await client.sendMessage(this.mjolnir.managementRoomId, { msgtype: 'm.text', body: `!mjolnir enable NsfwProtection` }); return await client.sendMessage(this.mjolnir.managementRoomId, { msgtype: 'm.text', body: `!mjolnir enable NsfwProtection` });
}); });
const data = readFileSync('test_tree.jpg') const data = readFileSync('test_tree.jpg');
const mxc = await client.uploadContent(data, 'image/png') const mxc = await client.uploadContent(data, 'image/png');
let content = {"msgtype": "m.image", "body": "test.jpeg", "url": mxc} let content = {"msgtype": "m.image", "body": "test.jpeg", "url": mxc};
let imageMessage = await client.sendMessage(room, content) let imageMessage = await client.sendMessage(room, content);
await delay(500) await delay(500);
let processedImage = await client.getEvent(room, imageMessage); let processedImage = await client.getEvent(room, imageMessage);
assert.equal(Object.keys(processedImage.content).length, 3, "This event should not have been redacted"); assert.equal(Object.keys(processedImage.content).length, 3, "This event should not have been redacted");
}); });
@ -47,19 +47,19 @@ describe("Test: NSFW protection", function () {
it("Nsfw protection redacts nsfw images", async function() { it("Nsfw protection redacts nsfw images", async function() {
this.timeout(20000); this.timeout(20000);
// dial the sensitivity on the protection way up so that all images are flagged as NSFW // dial the sensitivity on the protection way up so that all images are flagged as NSFW
this.mjolnir.config.nsfwSensitivity = 0.0 this.mjolnir.config.nsfwSensitivity = 0.0;
await client.sendMessage(this.mjolnir.managementRoomId, { msgtype: 'm.text', body: `!mjolnir rooms add ${room}` }); await client.sendMessage(this.mjolnir.managementRoomId, { msgtype: 'm.text', body: `!mjolnir rooms add ${room}` });
await getFirstReaction(client, this.mjolnir.managementRoomId, '✅', async () => { await getFirstReaction(client, this.mjolnir.managementRoomId, '✅', async () => {
return await client.sendMessage(this.mjolnir.managementRoomId, { msgtype: 'm.text', body: `!mjolnir enable NsfwProtection` }); return await client.sendMessage(this.mjolnir.managementRoomId, { msgtype: 'm.text', body: `!mjolnir enable NsfwProtection` });
}); });
const data = readFileSync('test_tree.jpg') const data = readFileSync('test_tree.jpg');
const mxc = await client.uploadContent(data, 'image/png') const mxc = await client.uploadContent(data, 'image/png');
let content = {"msgtype": "m.image", "body": "test.jpeg", "url": mxc} let content = {"msgtype": "m.image", "body": "test.jpeg", "url": mxc};
let imageMessage = await client.sendMessage(room, content) let imageMessage = await client.sendMessage(room, content);
await delay(500) await delay(500);
let processedImage = await client.getEvent(room, imageMessage); let processedImage = await client.getEvent(room, imageMessage);
assert.equal(Object.keys(processedImage.content).length, 0, "This event should have been redacted"); assert.equal(Object.keys(processedImage.content).length, 0, "This event should have been redacted");
}); });