gallery: implement AND/NOT for tag search

This commit is contained in:
Přemysl Eric Janouch 2024-01-22 19:29:51 +01:00
parent 4f174972e3
commit 083739fd4e
Signed by: p
GPG Key ID: A0420B94F92B9493
4 changed files with 120 additions and 25 deletions

View File

@ -76,7 +76,7 @@ CREATE TABLE IF NOT EXISTS tag_space(
id INTEGER NOT NULL, id INTEGER NOT NULL,
name TEXT NOT NULL, name TEXT NOT NULL,
description TEXT, description TEXT,
CHECK (name NOT LIKE '%:%'), CHECK (name NOT LIKE '%:%' AND name NOT LIKE '-%'),
PRIMARY KEY (id) PRIMARY KEY (id)
) STRICT; ) STRICT;

135
main.go
View File

@ -62,10 +62,33 @@ func hammingDistance(a, b int64) int {
return bits.OnesCount64(uint64(a) ^ uint64(b)) return bits.OnesCount64(uint64(a) ^ uint64(b))
} }
type productAggregator float64
func (pa *productAggregator) Step(v float64) {
*pa = productAggregator(float64(*pa) * v)
}
func (pa *productAggregator) Done() float64 {
return float64(*pa)
}
func newProductAggregator() *productAggregator {
pa := productAggregator(1)
return &pa
}
func init() { func init() {
sql.Register("sqlite3_custom", &sqlite3.SQLiteDriver{ sql.Register("sqlite3_custom", &sqlite3.SQLiteDriver{
ConnectHook: func(conn *sqlite3.SQLiteConn) error { ConnectHook: func(conn *sqlite3.SQLiteConn) error {
return conn.RegisterFunc("hamming", hammingDistance, true /*pure*/) if err := conn.RegisterFunc(
"hamming", hammingDistance, true /*pure*/); err != nil {
return err
}
if err := conn.RegisterAggregator(
"product", newProductAggregator, true /*pure*/); err != nil {
return err
}
return nil
}, },
}) })
} }
@ -956,17 +979,89 @@ func handleAPISimilar(w http.ResponseWriter, r *http.Request) {
} }
// - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - // - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
// This is the most miserable part of the whole program.
// NOTE: AND will mean MULTIPLY(IFNULL(ta.weight, 0)) per SHA1. const searchCTE1 = `WITH
const searchCTE = `WITH
matches(sha1, thumbw, thumbh, score) AS ( matches(sha1, thumbw, thumbh, score) AS (
SELECT i.sha1, i.thumbw, i.thumbh, ta.weight AS score SELECT i.sha1, i.thumbw, i.thumbh, ta.weight AS score
FROM tag_assignment AS ta FROM tag_assignment AS ta
JOIN image AS i ON i.sha1 = ta.sha1 JOIN image AS i ON i.sha1 = ta.sha1
WHERE ta.tag = ? WHERE ta.tag = %d
) )
` `
const searchCTEMulti = `WITH
positive(tag) AS (VALUES %s),
candidates(sha1) AS (%s),
matches(sha1, thumbw, thumbh, score) AS (
SELECT i.sha1, i.thumbw, i.thumbh,
product(IFNULL(ta.weight, 0)) AS score
FROM image AS i, positive AS p
JOIN candidates AS c ON i.sha1 = c.sha1
LEFT JOIN tag_assignment AS ta ON ta.sha1 = i.sha1 AND ta.tag = p.tag
GROUP BY i.sha1
)
`
func parseQuery(query string) (string, error) {
positive, negative := []int64{}, []int64{}
for _, word := range strings.Split(query, " ") {
if word == "" {
continue
}
space, tag, _ := strings.Cut(word, ":")
negated := false
if strings.HasPrefix(space, "-") {
space = space[1:]
negated = true
}
var tagID int64
err := db.QueryRow(`
SELECT t.id FROM tag AS t
JOIN tag_space AS ts ON t.space = ts.id
WHERE ts.name = ? AND t.name = ?`, space, tag).Scan(&tagID)
if err != nil {
return "", err
}
if negated {
negative = append(negative, tagID)
} else {
positive = append(positive, tagID)
}
}
// Don't return most of the database, and simplify the following builder.
if len(positive) == 0 {
return "", errors.New("search is too wide")
}
// Optimise single tag searches.
if len(positive) == 1 && len(negative) == 0 {
return fmt.Sprintf(searchCTE1, positive[0]), nil
}
values := fmt.Sprintf(`(%d)`, positive[0])
candidates := fmt.Sprintf(
`SELECT sha1 FROM tag_assignment WHERE tag = %d`, positive[0])
for _, tagID := range positive[1:] {
values += fmt.Sprintf(`, (%d)`, tagID)
candidates += fmt.Sprintf(` INTERSECT
SELECT sha1 FROM tag_assignment WHERE tag = %d`, tagID)
}
for _, tagID := range negative {
candidates += fmt.Sprintf(` EXCEPT
SELECT sha1 FROM tag_assignment WHERE tag = %d`, tagID)
}
return fmt.Sprintf(searchCTEMulti, values, candidates), nil
}
// - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
type webTagMatch struct { type webTagMatch struct {
SHA1 string `json:"sha1"` SHA1 string `json:"sha1"`
ThumbW int64 `json:"thumbW"` ThumbW int64 `json:"thumbW"`
@ -974,10 +1069,10 @@ type webTagMatch struct {
Score float32 `json:"score"` Score float32 `json:"score"`
} }
func getTagMatches(tag int64) (matches []webTagMatch, err error) { func getTagMatches(cte string) (matches []webTagMatch, err error) {
rows, err := db.Query(searchCTE+` rows, err := db.Query(cte + `
SELECT sha1, IFNULL(thumbw, 0), IFNULL(thumbh, 0), score SELECT sha1, IFNULL(thumbw, 0), IFNULL(thumbh, 0), score
FROM matches`, tag) FROM matches`)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -1001,13 +1096,13 @@ type webTagSupertag struct {
score float32 score float32
} }
func getTagSupertags(tag int64) (result map[int64]*webTagSupertag, err error) { func getTagSupertags(cte string) (result map[int64]*webTagSupertag, err error) {
rows, err := db.Query(searchCTE+` rows, err := db.Query(cte + `
SELECT DISTINCT ta.tag, ts.name, t.name SELECT DISTINCT ta.tag, ts.name, t.name
FROM tag_assignment AS ta FROM tag_assignment AS ta
JOIN matches AS m ON m.sha1 = ta.sha1 JOIN matches AS m ON m.sha1 = ta.sha1
JOIN tag AS t ON ta.tag = t.id JOIN tag AS t ON ta.tag = t.id
JOIN tag_space AS ts ON ts.id = t.space`, tag) JOIN tag_space AS ts ON ts.id = t.space`)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -1032,18 +1127,18 @@ type webTagRelated struct {
Score float32 `json:"score"` Score float32 `json:"score"`
} }
func getTagRelated(tag int64, matches int) ( func getTagRelated(cte string, matches int) (
result map[string][]webTagRelated, err error) { result map[string][]webTagRelated, err error) {
// Not sure if this level of efficiency is achievable directly in SQL. // Not sure if this level of efficiency is achievable directly in SQL.
supertags, err := getTagSupertags(tag) supertags, err := getTagSupertags(cte)
if err != nil { if err != nil {
return nil, err return nil, err
} }
rows, err := db.Query(searchCTE+` rows, err := db.Query(cte + `
SELECT ta.tag, ta.weight SELECT ta.tag, ta.weight
FROM tag_assignment AS ta FROM tag_assignment AS ta
JOIN matches AS m ON m.sha1 = ta.sha1`, tag) JOIN matches AS m ON m.sha1 = ta.sha1`)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -1084,13 +1179,7 @@ func handleAPISearch(w http.ResponseWriter, r *http.Request) {
Related map[string][]webTagRelated `json:"related"` Related map[string][]webTagRelated `json:"related"`
} }
space, tag, _ := strings.Cut(params.Query, ":") cte, err := parseQuery(params.Query)
var tagID int64
err := db.QueryRow(`
SELECT t.id FROM tag AS t
JOIN tag_space AS ts ON t.space = ts.id
WHERE ts.name = ? AND t.name = ?`, space, tag).Scan(&tagID)
if errors.Is(err, sql.ErrNoRows) { if errors.Is(err, sql.ErrNoRows) {
http.Error(w, err.Error(), http.StatusNotFound) http.Error(w, err.Error(), http.StatusNotFound)
return return
@ -1099,11 +1188,11 @@ func handleAPISearch(w http.ResponseWriter, r *http.Request) {
return return
} }
if result.Matches, err = getTagMatches(tagID); err != nil { if result.Matches, err = getTagMatches(cte); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError) http.Error(w, err.Error(), http.StatusInternalServerError)
return return
} }
if result.Related, err = getTagRelated(tagID, if result.Related, err = getTagRelated(cte,
len(result.Matches)); err != nil { len(result.Matches)); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError) http.Error(w, err.Error(), http.StatusInternalServerError)
return return

View File

@ -646,7 +646,11 @@ let Search = {
m(Header), m(Header),
m('.body', {}, [ m('.body', {}, [
m('.sidebar', [ m('.sidebar', [
m('p', SearchModel.query), m('input', {
value: SearchModel.query,
onchange: event => m.route.set(
`/search/:key`, {key: event.target.value}),
}),
m(SearchRelated), m(SearchRelated),
]), ]),
m(SearchView), m(SearchView),

View File

@ -27,6 +27,8 @@ a { color: inherit; }
.sidebar { padding: .25rem .5rem; background: var(--shade-color); .sidebar { padding: .25rem .5rem; background: var(--shade-color);
border-right: 1px solid #ccc; overflow: auto; border-right: 1px solid #ccc; overflow: auto;
min-width: 10rem; max-width: 20rem; flex-shrink: 0; } min-width: 10rem; max-width: 20rem; flex-shrink: 0; }
.sidebar input { width: 100%; box-sizing: border-box; margin: .5rem 0;
font-size: inherit; }
.sidebar h2 { margin: 0.5em 0 0.25em 0; padding: 0; font-size: 1.2rem; } .sidebar h2 { margin: 0.5em 0 0.25em 0; padding: 0; font-size: 1.2rem; }
.sidebar ul { margin: .5rem 0; padding: 0; } .sidebar ul { margin: .5rem 0; padding: 0; }