From 2329bb1555a5df543abe2a993d334c7616ab6cb7 Mon Sep 17 00:00:00 2001 From: Yaroslav Klyuyev Date: Mon, 16 Mar 2015 22:12:13 +0200 Subject: [PATCH] added utils.file_adapter utils.file_adapter should be used when source type may be file-like object or path to file. Used in FromCSVTablesGenerator and in FromTablesGenerator. Should resolve #6 --- pynames/from_tables_generator.py | 10 +++++----- pynames/tests/test_utils.py | 24 ++++++++++++++++++++++-- pynames/utils.py | 14 +++++++++++++- 3 files changed, 40 insertions(+), 8 deletions(-) diff --git a/pynames/from_tables_generator.py b/pynames/from_tables_generator.py index c5ad26f..1052b5d 100644 --- a/pynames/from_tables_generator.py +++ b/pynames/from_tables_generator.py @@ -13,7 +13,7 @@ from pynames import exceptions from pynames.base import BaseGenerator from pynames.names import Name from pynames.relations import GENDER, LANGUAGE, LANGUAGE_FORMS_LANGTH -from pynames.utils import is_file +from pynames.utils import file_adapter class Template(object): @@ -104,7 +104,7 @@ class FromTablesGenerator(BaseGenerator): error_msg = 'FromTablesGenerator: you must make subclass of FromTablesGenerator and define attribute SOURCE in it.' raise NotImplementedError(error_msg) - with open(source) as f: + with file_adapter(source) as f: data = json.load(f) self.native_language = data['native_language'] self.languages = set(data['languages']) @@ -237,7 +237,7 @@ class FromCSVTablesGenerator(FromTablesGenerator): self.full_forms_for_languages = set() def load_settings(self, settings_source): - with (settings_source if is_file(settings_source) else open(settings_source)) as settings_file: + with file_adapter(settings_source) as settings_file: reader = unicodecsv.DictReader(settings_file, encoding='utf-8') for row in reader: new_native_language = row.get('native_language', '').strip() @@ -258,7 +258,7 @@ class FromCSVTablesGenerator(FromTablesGenerator): def load_templates(self, templates_source): template_slugs = [] - with (templates_source if is_file(templates_source) else open(templates_source)) as templates_file: + with file_adapter(templates_source) as templates_file: reader = unicodecsv.DictReader(templates_file, encoding='utf-8') for row in reader: template_data = { @@ -275,7 +275,7 @@ class FromCSVTablesGenerator(FromTablesGenerator): return template_slugs def load_tables(self, tables_source): - with (tables_source if is_file(tables_source) else open(tables_source)) as tables_file: + with file_adapter(tables_source) as tables_file: reader = unicodecsv.DictReader(tables_file, encoding='utf-8') slugs = set([fieldname.split(':')[0] for fieldname in reader.fieldnames]) for slug in slugs: diff --git a/pynames/tests/test_utils.py b/pynames/tests/test_utils.py index fc58a2c..c97c02e 100644 --- a/pynames/tests/test_utils.py +++ b/pynames/tests/test_utils.py @@ -1,9 +1,11 @@ # coding: utf-8 -import unittest +import os import tempfile +import unittest -from pynames.utils import is_file +from pynames.utils import is_file, file_adapter +import pynames try: from django.core.files import File @@ -27,3 +29,21 @@ class TestName(unittest.TestCase): self.assertTrue(is_file(UploadedFile('mock'))) self.assertTrue(is_file(File('mock'))) self.assertTrue(is_file(ContentFile('mock'))) + + def test_file_dapter(self): + root_dir = os.path.dirname(pynames.__file__) + + test_file_path = os.path.join(root_dir, 'tests', 'fixtures', 'test_from_list_generator.json') + + with open(test_file_path) as f: + target_content = f.read() + + with file_adapter(test_file_path) as f: + self.assertEqual(f.read(), target_content) + + django_file_object = ContentFile(target_content) + classic_file_object = open(test_file_path, 'r') + + for tested_file_object in [django_file_object, classic_file_object]: + with file_adapter(tested_file_object) as f: + self.assertEqual(f.read(), target_content) diff --git a/pynames/utils.py b/pynames/utils.py index a0e1c0b..5aa5cac 100644 --- a/pynames/utils.py +++ b/pynames/utils.py @@ -1,7 +1,8 @@ # coding: utf-8 -import os +import contextlib import importlib +import os import pynames @@ -54,3 +55,14 @@ def is_file(obj): + [any([callable(getattr(obj, method_name, None)) for method_name in ('next', '__iter__')])] ) + + +@contextlib.contextmanager +def file_adapter(file_or_path): + """Context manager that works similar to ``open(file_path)``but also accepts already openned file-like objects.""" + if is_file(file_or_path): + file_obj = file_or_path + else: + file_obj = open(file_or_path) + yield file_obj + file_obj.close()