[py3] Made GeoIP tests pass with Python 3

This commit is contained in:
Claude Paroz 2013-05-10 13:18:07 +02:00
parent 465a29abe0
commit 7b00d90208
3 changed files with 17 additions and 21 deletions

View File

@ -137,9 +137,6 @@ class GeoIP(object):
if not isinstance(query, six.string_types): if not isinstance(query, six.string_types):
raise TypeError('GeoIP query must be a string, not type %s' % type(query).__name__) raise TypeError('GeoIP query must be a string, not type %s' % type(query).__name__)
# GeoIP only takes ASCII-encoded strings.
query = query.encode('ascii')
# Extra checks for the existence of country and city databases. # Extra checks for the existence of country and city databases.
if city_or_country and not (self._country or self._city): if city_or_country and not (self._country or self._city):
raise GeoIPException('Invalid GeoIP country and city data files.') raise GeoIPException('Invalid GeoIP country and city data files.')
@ -148,8 +145,8 @@ class GeoIP(object):
elif city and not self._city: elif city and not self._city:
raise GeoIPException('Invalid GeoIP city data file: %s' % self._city_file) raise GeoIPException('Invalid GeoIP city data file: %s' % self._city_file)
# Return the query string back to the caller. # Return the query string back to the caller. GeoIP only takes bytestrings.
return query return force_bytes(query)
def city(self, query): def city(self, query):
""" """
@ -157,33 +154,33 @@ class GeoIP(object):
Fully Qualified Domain Name (FQDN). Some information in the dictionary Fully Qualified Domain Name (FQDN). Some information in the dictionary
may be undefined (None). may be undefined (None).
""" """
query = self._check_query(query, city=True) enc_query = self._check_query(query, city=True)
if ipv4_re.match(query): if ipv4_re.match(query):
# If an IP address was passed in # If an IP address was passed in
return GeoIP_record_by_addr(self._city, c_char_p(query)) return GeoIP_record_by_addr(self._city, c_char_p(enc_query))
else: else:
# If a FQDN was passed in. # If a FQDN was passed in.
return GeoIP_record_by_name(self._city, c_char_p(query)) return GeoIP_record_by_name(self._city, c_char_p(enc_query))
def country_code(self, query): def country_code(self, query):
"Returns the country code for the given IP Address or FQDN." "Returns the country code for the given IP Address or FQDN."
query = self._check_query(query, city_or_country=True) enc_query = self._check_query(query, city_or_country=True)
if self._country: if self._country:
if ipv4_re.match(query): if ipv4_re.match(query):
return GeoIP_country_code_by_addr(self._country, query) return GeoIP_country_code_by_addr(self._country, enc_query)
else: else:
return GeoIP_country_code_by_name(self._country, query) return GeoIP_country_code_by_name(self._country, enc_query)
else: else:
return self.city(query)['country_code'] return self.city(query)['country_code']
def country_name(self, query): def country_name(self, query):
"Returns the country name for the given IP Address or FQDN." "Returns the country name for the given IP Address or FQDN."
query = self._check_query(query, city_or_country=True) enc_query = self._check_query(query, city_or_country=True)
if self._country: if self._country:
if ipv4_re.match(query): if ipv4_re.match(query):
return GeoIP_country_name_by_addr(self._country, query) return GeoIP_country_name_by_addr(self._country, enc_query)
else: else:
return GeoIP_country_name_by_name(self._country, query) return GeoIP_country_name_by_name(self._country, enc_query)
else: else:
return self.city(query)['country_name'] return self.city(query)['country_name']

View File

@ -92,7 +92,7 @@ def check_string(result, func, cargs):
free(result) free(result)
else: else:
s = '' s = ''
return s return s.decode()
GeoIP_database_info = lgeoip.GeoIP_database_info GeoIP_database_info = lgeoip.GeoIP_database_info
GeoIP_database_info.restype = geoip_char_p GeoIP_database_info.restype = geoip_char_p
@ -100,7 +100,12 @@ GeoIP_database_info.errcheck = check_string
# String output routines. # String output routines.
def string_output(func): def string_output(func):
def _err_check(result, func, cargs):
if result:
return result.decode()
return result
func.restype = c_char_p func.restype = c_char_p
func.errcheck = _err_check
return func return func
GeoIP_country_code_by_addr = string_output(lgeoip.GeoIP_country_code_by_addr) GeoIP_country_code_by_addr = string_output(lgeoip.GeoIP_country_code_by_addr)

View File

@ -106,12 +106,6 @@ class GeoIPTest(unittest.TestCase):
d = g.city("www.osnabrueck.de") d = g.city("www.osnabrueck.de")
self.assertEqual('Osnabrück', d['city']) self.assertEqual('Osnabrück', d['city'])
def test06_unicode_query(self):
"Testing that GeoIP accepts unicode string queries, see #17059."
g = GeoIP()
d = g.country('whitehouse.gov')
self.assertEqual('US', d['country_code'])
def suite(): def suite():
s = unittest.TestSuite() s = unittest.TestSuite()