Fixed #35100 -- Reworked GeoIP2 database initialization.

This commit is contained in:
Nick Pope 2021-04-06 09:52:09 +01:00 committed by Mariusz Felisiak
parent 8d2c16252e
commit 40b5b1596f
3 changed files with 75 additions and 131 deletions

View File

@ -21,6 +21,7 @@ from django.core.exceptions import ValidationError
from django.core.validators import validate_ipv46_address
from django.utils._os import to_path
from django.utils.deprecation import RemovedInDjango60Warning
from django.utils.functional import cached_property
__all__ = ["HAS_GEOIP2"]
@ -53,13 +54,8 @@ class GeoIP2:
(MODE_AUTO, MODE_MMAP_EXT, MODE_MMAP, MODE_FILE, MODE_MEMORY)
)
# Paths to the city & country binary databases.
_city_file = ""
_country_file = ""
# Initially, pointers to GeoIP file references are NULL.
_city = None
_country = None
_path = None
_reader = None
def __init__(self, path=None, cache=0, country=None, city=None):
"""
@ -84,114 +80,69 @@ class GeoIP2:
* city: The name of the GeoIP city data file. Defaults to
'GeoLite2-City.mmdb'; overrides the GEOIP_CITY setting.
"""
# Checking the given cache option.
if cache not in self.cache_options:
raise GeoIP2Exception("Invalid GeoIP caching option: %s" % cache)
# Getting the GeoIP data path.
path = path or getattr(settings, "GEOIP_PATH", None)
city = city or getattr(settings, "GEOIP_CITY", "GeoLite2-City.mmdb")
country = country or getattr(settings, "GEOIP_COUNTRY", "GeoLite2-Country.mmdb")
if not path:
raise GeoIP2Exception(
"GeoIP path must be provided via parameter or the GEOIP_PATH setting."
)
path = to_path(path)
if path.is_dir():
# Constructing the GeoIP database filenames using the settings
# dictionary. If the database files for the GeoLite country
# and/or city datasets exist, then try to open them.
country_db = path / (
country or getattr(settings, "GEOIP_COUNTRY", "GeoLite2-Country.mmdb")
)
if country_db.is_file():
self._country = geoip2.database.Reader(str(country_db), mode=cache)
self._country_file = country_db
city_db = path / (
city or getattr(settings, "GEOIP_CITY", "GeoLite2-City.mmdb")
)
if city_db.is_file():
self._city = geoip2.database.Reader(str(city_db), mode=cache)
self._city_file = city_db
if not self._reader:
raise GeoIP2Exception("Could not load a database from %s." % path)
elif path.is_file():
# Otherwise, some detective work will be needed to figure out
# whether the given database path is for the GeoIP country or city
# databases.
reader = geoip2.database.Reader(str(path), mode=cache)
db_type = reader.metadata().database_type
if "City" in db_type:
# GeoLite City database detected.
self._city = reader
self._city_file = path
elif "Country" in db_type:
# GeoIP Country database detected.
self._country = reader
self._country_file = path
else:
raise GeoIP2Exception(
"Unable to recognize database edition: %s" % db_type
)
# Try the path first in case it is the full path to a database.
for path in (path, path / city, path / country):
if path.is_file():
self._path = path
self._reader = geoip2.database.Reader(path, mode=cache)
break
else:
raise GeoIP2Exception("GeoIP path must be a valid file or directory.")
raise GeoIP2Exception(
"Path must be a valid database or directory containing databases."
)
@property
def _reader(self):
return self._country or self._city
@property
def _country_or_city(self):
if self._country:
return self._country.country
else:
return self._city.city
database_type = self._metadata.database_type
if not database_type.endswith(("City", "Country")):
raise GeoIP2Exception(f"Unable to handle database edition: {database_type}")
def __del__(self):
# Cleanup any GeoIP file handles lying around.
if self._city:
self._city.close()
if self._country:
self._country.close()
if self._reader:
self._reader.close()
def __repr__(self):
meta = self._reader.metadata()
version = "[v%s.%s]" % (
meta.binary_format_major_version,
meta.binary_format_minor_version,
)
return (
'<%(cls)s %(version)s _country_file="%(country)s", _city_file="%(city)s">'
% {
"cls": self.__class__.__name__,
"version": version,
"country": self._country_file,
"city": self._city_file,
}
)
m = self._metadata
version = f"v{m.binary_format_major_version}.{m.binary_format_minor_version}"
return f"<{self.__class__.__name__} [{version}] _path='{self._path}'>"
def _check_query(self, query, city=False, city_or_country=False):
"Check the query and database availability."
@cached_property
def _metadata(self):
# Metadata object returned by the underlying reader's metadata() call.
# Wrapped in cached_property (imported from django.utils.functional), so
# the value is presumably computed once per instance and then reused.
return self._reader.metadata()
def _query(self, query, *, require_city=False):
if not isinstance(query, (str, ipaddress.IPv4Address, ipaddress.IPv6Address)):
raise TypeError(
"GeoIP query must be a string or instance of IPv4Address or "
"IPv6Address, not type %s" % type(query).__name__,
)
# Extra checks for the existence of country and city databases.
if city_or_country and not (self._country or self._city):
raise GeoIP2Exception("Invalid GeoIP country and city data files.")
elif city and not self._city:
raise GeoIP2Exception("Invalid GeoIP city data file: %s" % self._city_file)
is_city = self._metadata.database_type.endswith("City")
if require_city and not is_city:
raise GeoIP2Exception(f"Invalid GeoIP city data file: {self._path}")
# Return the query string back to the caller. GeoIP2 only takes IP addresses.
try:
validate_ipv46_address(query)
except ValidationError:
# GeoIP2 only takes IP addresses, so try to resolve a hostname.
query = socket.gethostbyname(query)
return query
function = self._reader.city if is_city else self._reader.country
return function(query)
def city(self, query):
"""
@ -199,8 +150,7 @@ class GeoIP2:
Fully Qualified Domain Name (FQDN). Some information in the dictionary
may be undefined (None).
"""
enc_query = self._check_query(query, city=True)
response = self._city.city(enc_query)
response = self._query(query, require_city=True)
region = response.subdivisions[0] if response.subdivisions else None
return {
"accuracy_radius": response.location.accuracy_radius,
@ -236,9 +186,7 @@ class GeoIP2:
IP address or a Fully Qualified Domain Name (FQDN). For example, both
'24.124.1.80' and 'djangoproject.com' are valid parameters.
"""
# Returning the country code and name
enc_query = self._check_query(query, city_or_country=True)
response = self._country_or_city(enc_query)
response = self._query(query, require_city=False)
return {
"continent_code": response.continent.code,
"continent_name": response.continent.name,

View File

@ -305,6 +305,13 @@ backends.
* Support for GDAL 2.4 is removed.
* :class:`~django.contrib.gis.geoip2.GeoIP2` no longer opens both city and
country databases when a directory path is provided, preferring the city
database, if it is available. The country database is a subset of the city
database, so the two are not typically needed together. If you need to use the
country database when it is in the same directory as the city database,
explicitly pass the country database path to the constructor.
Dropped support for MariaDB 10.4
--------------------------------

View File

@ -50,17 +50,11 @@ class GeoLite2Test(SimpleTestCase):
g2 = GeoIP2(settings.GEOIP_PATH, GeoIP2.MODE_AUTO)
# Path provided as a string.
g3 = GeoIP2(str(settings.GEOIP_PATH))
for g in (g1, g2, g3):
self.assertTrue(g._country)
self.assertTrue(g._city)
# Only passing in the location of one database.
g4 = GeoIP2(settings.GEOIP_PATH / settings.GEOIP_CITY, country="")
self.assertTrue(g4._city)
self.assertIsNone(g4._country)
g5 = GeoIP2(settings.GEOIP_PATH / settings.GEOIP_COUNTRY, city="")
self.assertTrue(g5._country)
self.assertIsNone(g5._city)
for g in (g1, g2, g3, g4, g5):
self.assertTrue(g._reader)
# Improper parameters.
bad_params = (23, "foo", 15.23)
@ -76,7 +70,7 @@ class GeoLite2Test(SimpleTestCase):
def test_no_database_file(self):
invalid_path = pathlib.Path(__file__).parent.joinpath("data/invalid").resolve()
msg = f"Could not load a database from {invalid_path}."
msg = "Path must be a valid database or directory containing databases."
with self.assertRaisesMessage(GeoIP2Exception, msg):
GeoIP2(invalid_path)
@ -103,6 +97,25 @@ class GeoLite2Test(SimpleTestCase):
def test_country(self):
# An invalid city filename is passed, presumably so that the city database
# cannot be found and the country database is the one that gets loaded; the
# assertion below confirms the loaded database type ends with "Country".
g = GeoIP2(city="<invalid>")
self.assertIs(g._metadata.database_type.endswith("Country"), True)
# Every entry in self.query_values is expected to produce the same country
# data; subTest reports a failure per query value.
for query in self.query_values:
with self.subTest(query=query):
self.assertEqual(
g.country(query),
{
"continent_code": "EU",
"continent_name": "Europe",
"country_code": "GB",
"country_name": "United Kingdom",
"is_in_european_union": False,
},
)
# The convenience accessors must agree with the dictionary checked above.
self.assertEqual(g.country_code(query), "GB")
self.assertEqual(g.country_name(query), "United Kingdom")
def test_country_using_city_database(self):
g = GeoIP2(country="<invalid>")
self.assertIs(g._metadata.database_type.endswith("City"), True)
for query in self.query_values:
with self.subTest(query=query):
self.assertEqual(
@ -120,6 +133,7 @@ class GeoLite2Test(SimpleTestCase):
def test_city(self):
g = GeoIP2(country="<invalid>")
self.assertIs(g._metadata.database_type.endswith("City"), True)
for query in self.query_values:
with self.subTest(query=query):
self.assertEqual(
@ -179,40 +193,16 @@ class GeoLite2Test(SimpleTestCase):
def test_del(self):
g = GeoIP2()
city = g._city
country = g._country
self.assertIs(city._db_reader.closed, False)
self.assertIs(country._db_reader.closed, False)
reader = g._reader
self.assertIs(reader._db_reader.closed, False)
del g
self.assertIs(city._db_reader.closed, True)
self.assertIs(country._db_reader.closed, True)
self.assertIs(reader._db_reader.closed, True)
def test_repr(self):
g = GeoIP2()
meta = g._reader.metadata()
version = "%s.%s" % (
meta.binary_format_major_version,
meta.binary_format_minor_version,
)
country_path = g._country_file
city_path = g._city_file
expected = (
'<GeoIP2 [v%(version)s] _country_file="%(country)s", _city_file="%(city)s">'
% {
"version": version,
"country": country_path,
"city": city_path,
}
)
self.assertEqual(repr(g), expected)
def test_check_query(self):
g = GeoIP2()
self.assertEqual(g._check_query(self.fqdn), self.ipv4_str)
self.assertEqual(g._check_query(self.ipv4_str), self.ipv4_str)
self.assertEqual(g._check_query(self.ipv6_str), self.ipv6_str)
self.assertEqual(g._check_query(self.ipv4_addr), self.ipv4_addr)
self.assertEqual(g._check_query(self.ipv6_addr), self.ipv6_addr)
m = g._metadata
version = f"{m.binary_format_major_version}.{m.binary_format_minor_version}"
self.assertEqual(repr(g), f"<GeoIP2 [v{version}] _path='{g._path}'>")
def test_coords_deprecation_warning(self):
g = GeoIP2()
@ -226,8 +216,7 @@ class GeoLite2Test(SimpleTestCase):
msg = "GeoIP2.open() is deprecated. Use GeoIP2() instead."
with self.assertWarnsMessage(RemovedInDjango60Warning, msg):
g = GeoIP2.open(settings.GEOIP_PATH, GeoIP2.MODE_AUTO)
self.assertTrue(g._country)
self.assertTrue(g._city)
self.assertTrue(g._reader)
@skipUnless(HAS_GEOIP2, "GeoIP2 is required.")
@ -248,7 +237,7 @@ class ErrorTest(SimpleTestCase):
GeoIP2()
def test_unsupported_database(self):
msg = "Unable to recognize database edition: GeoLite2-ASN"
msg = "Unable to handle database edition: GeoLite2-ASN"
with self.settings(GEOIP_PATH=build_geoip_path("GeoLite2-ASN-Test.mmdb")):
with self.assertRaisesMessage(GeoIP2Exception, msg):
GeoIP2()