From 16ef6d19c316000aa19167567aaa5543d44712ed Mon Sep 17 00:00:00 2001 From: Yaroslav Klyuyev Date: Mon, 16 Feb 2015 21:02:28 +0200 Subject: [PATCH] FromCSVTablesGenerator.load_*() now accepts both path to file and file object as argument --- pynames/from_tables_generator.py | 13 +++++++------ pynames/utils.py | 9 +++++++++ 2 files changed, 16 insertions(+), 6 deletions(-) diff --git a/pynames/from_tables_generator.py b/pynames/from_tables_generator.py index 51d6380..c5ad26f 100644 --- a/pynames/from_tables_generator.py +++ b/pynames/from_tables_generator.py @@ -9,10 +9,11 @@ from collections import Iterable import unicodecsv # pynames: -from pynames.relations import GENDER, LANGUAGE, LANGUAGE_FORMS_LANGTH -from pynames.names import Name -from pynames.base import BaseGenerator from pynames import exceptions +from pynames.base import BaseGenerator +from pynames.names import Name +from pynames.relations import GENDER, LANGUAGE, LANGUAGE_FORMS_LANGTH +from pynames.utils import is_file class Template(object): @@ -236,7 +237,7 @@ class FromCSVTablesGenerator(FromTablesGenerator): self.full_forms_for_languages = set() def load_settings(self, settings_source): - with open(settings_source) as settings_file: + with (settings_source if is_file(settings_source) else open(settings_source)) as settings_file: reader = unicodecsv.DictReader(settings_file, encoding='utf-8') for row in reader: new_native_language = row.get('native_language', '').strip() @@ -257,7 +258,7 @@ class FromCSVTablesGenerator(FromTablesGenerator): def load_templates(self, templates_source): template_slugs = [] - with open(templates_source) as templates_file: + with (templates_source if is_file(templates_source) else open(templates_source)) as templates_file: reader = unicodecsv.DictReader(templates_file, encoding='utf-8') for row in reader: template_data = { @@ -274,7 +275,7 @@ class FromCSVTablesGenerator(FromTablesGenerator): return template_slugs def load_tables(self, tables_source): - with open(tables_source) as tables_file: + with (tables_source if is_file(tables_source) else open(tables_source)) as tables_file: reader = unicodecsv.DictReader(tables_file, encoding='utf-8') slugs = set([fieldname.split(':')[0] for fieldname in reader.fieldnames]) for slug in slugs: diff --git a/pynames/utils.py b/pynames/utils.py index afa71cf..92178d9 100644 --- a/pynames/utils.py +++ b/pynames/utils.py @@ -41,3 +41,12 @@ def get_all_generators(): generators.append(generator) return generators + + +def is_file(obj): + """Retrun True is object has 'next', '__enter__' and '__exit__' methods. + + Suitable to check both builtin ``file`` and ``django.core.file.File`` instances. + + """ + return all([callable(getattr(obj, method_name, None)) for method_name in ('next', '__enter__', '__exit__')])