diff --git a/onionshare/__init__.py b/onionshare/__init__.py index 86f03b84..52226b48 100644 --- a/onionshare/__init__.py +++ b/onionshare/__init__.py @@ -21,7 +21,7 @@ along with this program. If not, see . import os, sys, time, argparse, threading from . import strings -from .common import Common +from .common import Common, DownloadsDirErrorCannotCreate, DownloadsDirErrorNotWritable from .web import Web from .onion import * from .onionshare import OnionShare @@ -92,17 +92,19 @@ def main(cwd=None): # In receive mode, validate downloads dir if receive: valid = True - if not os.path.isdir(common.settings.get('downloads_dir')): - try: - os.mkdir(common.settings.get('downloads_dir'), 0o700) - except: - print(strings._('error_cannot_create_downloads_dir').format(common.settings.get('downloads_dir'))) - valid = False - if valid and not os.access(common.settings.get('downloads_dir'), os.W_OK): + try: + common.validate_downloads_dir() + + except DownloadsDirErrorCannotCreate: + print(strings._('error_cannot_create_downloads_dir').format(common.settings.get('downloads_dir'))) + valid = False + + except DownloadsDirErrorNotWritable: print(strings._('error_downloads_dir_not_writable').format(common.settings.get('downloads_dir'))) valid = False - if not valid: - sys.exit() + + if not valid: + sys.exit() # Create the Web object web = Web(common, False, receive) diff --git a/onionshare/common.py b/onionshare/common.py index 628064df..ad2f4574 100644 --- a/onionshare/common.py +++ b/onionshare/common.py @@ -31,6 +31,21 @@ import time from .settings import Settings + +class DownloadsDirErrorCannotCreate(Exception): + """ + Error creating the downloads dir (~/OnionShare by default). + """ + pass + + +class DownloadsDirErrorNotWritable(Exception): + """ + Downloads dir is not writable. + """ + pass + + class Common(object): """ The Common object is shared amongst all parts of OnionShare. @@ -321,6 +336,19 @@ class Common(object): }""" } + def validate_downloads_dir(self): + """ + Validate that downloads_dir exists, and create it if it doesn't + """ + if not os.path.isdir(self.settings.get('downloads_dir')): + try: + os.mkdir(self.settings.get('downloads_dir'), 0o700) + except: + raise DownloadsDirErrorCannotCreate + + if not os.access(self.settings.get('downloads_dir'), os.W_OK): + raise DownloadsDirErrorNotWritable + @staticmethod def random_string(num_bytes, output_len=None): """