Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 97 additions & 13 deletions country/code_fetcher.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
package country

import (
"maps"
"regexp"
"slices"
"strings"
"sync"
"unicode"

uslices "github.com/upfluence/pkg/slices"
"github.com/upfluence/pkg/v2/stringutil"
)

Expand Down Expand Up @@ -40,8 +45,50 @@ var (
}
)

type SearchOperator func(key, searchTerm string) bool

func SearchOperatorContains(key, searchTerm string) bool {
return strings.Contains(strings.ToLower(key), strings.ToLower(searchTerm))
}

func SearchOperatorMatchBoolPrefix(key, searchTerm string) bool {
words := strings.FieldsFunc(searchTerm, unicode.IsSpace)

if len(words) == 0 {
return true
}

for _, word := range words[:len(words)-1] {
// We use `(?i)` to make the match case-insensitive
// We use `regexp.QuoteMeta` to escape any special characters in the word
pattern := `(?i)\b` + regexp.QuoteMeta(word) + `\b`
matched, err := regexp.MatchString(pattern, key)

if err != nil {
// In the unlikely event of a regex compilation error, return false
return false
}

if !matched {
return false
}
}

regexp.QuoteMeta(words[len(words)-1])
pattern := `(?i)\b` + regexp.QuoteMeta(words[len(words)-1])
matched, err := regexp.MatchString(pattern, key)

if err != nil {
// In the unlikely event of a regex compilation error, return false
return false
}

return matched
}

type CodeFetcher interface {
Fetch(string) (CountryCode, bool)
Search(string, SearchOperator) []CountryCode
}

type IndexedCodeFetcher struct {
Expand All @@ -60,25 +107,46 @@ func (icf *IndexedCodeFetcher) Fetch(key string) (CountryCode, bool) {
return CountryCode{}, false
}

icf.once.Do(func() {
ccs := icf.CountryCodes
icf.once.Do(icf.prepareIndex)

if ccs == nil {
ccs = DefaultCountryCodes
}
cc, ok := icf.indexedCountryCodes[icf.NormalizeKey(key)]

icf.indexedCountryCodes = make(map[string]CountryCode, len(ccs))
return cc, ok
}

func (icf *IndexedCodeFetcher) Search(searchTerm string, operator SearchOperator) []CountryCode {
if searchTerm == "" {
return nil
}

for _, cc := range ccs {
for _, k := range icf.ExtractKeys(cc) {
icf.indexedCountryCodes[icf.NormalizeKey(k)] = cc
icf.once.Do(icf.prepareIndex)

return uslices.Reduce(
slices.Collect(maps.Keys(icf.indexedCountryCodes)),
func(acc []CountryCode, key string) []CountryCode {
if operator(key, searchTerm) {
return append(acc, icf.indexedCountryCodes[key])
}
}
})

cc, ok := icf.indexedCountryCodes[icf.NormalizeKey(key)]
return acc
},
)
}

return cc, ok
func (icf *IndexedCodeFetcher) prepareIndex() {
ccs := icf.CountryCodes

if ccs == nil {
ccs = DefaultCountryCodes
}

icf.indexedCountryCodes = make(map[string]CountryCode, len(ccs))

for _, cc := range ccs {
for _, k := range icf.ExtractKeys(cc) {
icf.indexedCountryCodes[icf.NormalizeKey(k)] = cc
}
}
}

type MultiCodeFetcher []CodeFetcher
Expand All @@ -96,3 +164,19 @@ func (cfs MultiCodeFetcher) Fetch(key string) (CountryCode, bool) {

return CountryCode{}, false
}

func (cfs MultiCodeFetcher) Search(key string, operator SearchOperator) []CountryCode {
if key == "" {
return nil
}

countryCodeByAlpha2 := map[string]CountryCode{}

for _, cf := range cfs {
for _, cc := range cf.Search(key, operator) {
countryCodeByAlpha2[cc.Alpha2] = cc
}
}

return slices.Collect(maps.Values(countryCodeByAlpha2))
}
82 changes: 71 additions & 11 deletions country/code_fetcher_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,20 +6,18 @@ import (
"github.com/stretchr/testify/assert"
)

func TestCodeFetcher(t *testing.T) {
var (
zeroValue CountryCode
func mustFetch(k string) CountryCode {
cc, ok := Alpha2CodeFetcher.Fetch(k)

mustFetch = func(k string) CountryCode {
cc, ok := Alpha2CodeFetcher.Fetch(k)
if !ok {
panic("not found")
}

if !ok {
panic("not found")
}
return cc
}

return cc
}
)
func TestCodeFetcher_Fetch(t *testing.T) {
var zeroValue CountryCode

for _, tt := range []struct {
k string
Expand Down Expand Up @@ -71,3 +69,65 @@ func TestCodeFetcher(t *testing.T) {
}
}
}

func TestCodeFetcher_Search(t *testing.T) {
for _, tt := range []struct {
searchTerm string
searchOperatior SearchOperator
want []CountryCode
}{
{
searchTerm: "United",
searchOperatior: SearchOperatorContains,
want: []CountryCode{
mustFetch("US"),
mustFetch("UM"),
mustFetch("AE"),
mustFetch("TZ"),
mustFetch("UK"),
},
},
{
searchTerm: "fRa",
searchOperatior: SearchOperatorContains,
want: []CountryCode{
mustFetch("FR"),
mustFetch("FX"),
},
},
{
searchTerm: "DE",
searchOperatior: SearchOperatorContains,
want: []CountryCode{
mustFetch("CD"),
mustFetch("FM"),
mustFetch("GP"),
mustFetch("RU"),
mustFetch("DE"),
mustFetch("LA"),
mustFetch("CV"),
mustFetch("BD"),
mustFetch("KP"),
mustFetch("DK"),
mustFetch("SE"),
},
},
{
searchTerm: "States Mi",
searchOperatior: SearchOperatorContains,
want: []CountryCode{
mustFetch("UM"),
},
},
{
searchTerm: "States Mi",
searchOperatior: SearchOperatorMatchBoolPrefix,
want: []CountryCode{
mustFetch("UM"), // "United States Minor Outlying Islands"
mustFetch("FM"), // "Micronesia, Federated States of"
},
},
} {
assert.ElementsMatch(t, tt.want, DefaultCodeFetcher.Search(tt.searchTerm, tt.searchOperatior))
}
}
Loading
Loading