diff --git a/initialize.sql b/initialize.sql index 5a54a7f..292c50b 100644 --- a/initialize.sql +++ b/initialize.sql @@ -76,7 +76,7 @@ CREATE TABLE IF NOT EXISTS tag_space( id INTEGER NOT NULL, name TEXT NOT NULL, description TEXT, - CHECK (name NOT LIKE '%:%'), + CHECK (name NOT LIKE '%:%' AND name NOT LIKE '-%'), PRIMARY KEY (id) ) STRICT; diff --git a/main.go b/main.go index 7c9aceb..644d2d4 100644 --- a/main.go +++ b/main.go @@ -62,10 +62,33 @@ func hammingDistance(a, b int64) int { return bits.OnesCount64(uint64(a) ^ uint64(b)) } +type productAggregator float64 + +func (pa *productAggregator) Step(v float64) { + *pa = productAggregator(float64(*pa) * v) +} + +func (pa *productAggregator) Done() float64 { + return float64(*pa) +} + +func newProductAggregator() *productAggregator { + pa := productAggregator(1) + return &pa +} + func init() { sql.Register("sqlite3_custom", &sqlite3.SQLiteDriver{ ConnectHook: func(conn *sqlite3.SQLiteConn) error { - return conn.RegisterFunc("hamming", hammingDistance, true /*pure*/) + if err := conn.RegisterFunc( + "hamming", hammingDistance, true /*pure*/); err != nil { + return err + } + if err := conn.RegisterAggregator( + "product", newProductAggregator, true /*pure*/); err != nil { + return err + } + return nil }, }) } @@ -956,17 +979,89 @@ func handleAPISimilar(w http.ResponseWriter, r *http.Request) { } // - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - +// This is the most miserable part of the whole program. -// NOTE: AND will mean MULTIPLY(IFNULL(ta.weight, 0)) per SHA1. -const searchCTE = `WITH +const searchCTE1 = `WITH matches(sha1, thumbw, thumbh, score) AS ( SELECT i.sha1, i.thumbw, i.thumbh, ta.weight AS score FROM tag_assignment AS ta JOIN image AS i ON i.sha1 = ta.sha1 - WHERE ta.tag = ? + WHERE ta.tag = %d ) ` +const searchCTEMulti = `WITH + positive(tag) AS (VALUES %s), + candidates(sha1) AS (%s), + matches(sha1, thumbw, thumbh, score) AS ( + SELECT i.sha1, i.thumbw, i.thumbh, + product(IFNULL(ta.weight, 0)) AS score + FROM image AS i, positive AS p + JOIN candidates AS c ON i.sha1 = c.sha1 + LEFT JOIN tag_assignment AS ta ON ta.sha1 = i.sha1 AND ta.tag = p.tag + GROUP BY i.sha1 + ) +` + +func parseQuery(query string) (string, error) { + positive, negative := []int64{}, []int64{} + for _, word := range strings.Split(query, " ") { + if word == "" { + continue + } + + space, tag, _ := strings.Cut(word, ":") + + negated := false + if strings.HasPrefix(space, "-") { + space = space[1:] + negated = true + } + + var tagID int64 + err := db.QueryRow(` + SELECT t.id FROM tag AS t + JOIN tag_space AS ts ON t.space = ts.id + WHERE ts.name = ? AND t.name = ?`, space, tag).Scan(&tagID) + if err != nil { + return "", err + } + + if negated { + negative = append(negative, tagID) + } else { + positive = append(positive, tagID) + } + } + + // Don't return most of the database, and simplify the following builder. + if len(positive) == 0 { + return "", errors.New("search is too wide") + } + + // Optimise single tag searches. + if len(positive) == 1 && len(negative) == 0 { + return fmt.Sprintf(searchCTE1, positive[0]), nil + } + + values := fmt.Sprintf(`(%d)`, positive[0]) + candidates := fmt.Sprintf( + `SELECT sha1 FROM tag_assignment WHERE tag = %d`, positive[0]) + for _, tagID := range positive[1:] { + values += fmt.Sprintf(`, (%d)`, tagID) + candidates += fmt.Sprintf(` INTERSECT + SELECT sha1 FROM tag_assignment WHERE tag = %d`, tagID) + } + for _, tagID := range negative { + candidates += fmt.Sprintf(` EXCEPT + SELECT sha1 FROM tag_assignment WHERE tag = %d`, tagID) + } + + return fmt.Sprintf(searchCTEMulti, values, candidates), nil +} + +// - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + type webTagMatch struct { SHA1 string `json:"sha1"` ThumbW int64 `json:"thumbW"` @@ -974,10 +1069,10 @@ type webTagMatch struct { Score float32 `json:"score"` } -func getTagMatches(tag int64) (matches []webTagMatch, err error) { - rows, err := db.Query(searchCTE+` +func getTagMatches(cte string) (matches []webTagMatch, err error) { + rows, err := db.Query(cte + ` SELECT sha1, IFNULL(thumbw, 0), IFNULL(thumbh, 0), score - FROM matches`, tag) + FROM matches`) if err != nil { return nil, err } @@ -1001,13 +1096,13 @@ type webTagSupertag struct { score float32 } -func getTagSupertags(tag int64) (result map[int64]*webTagSupertag, err error) { - rows, err := db.Query(searchCTE+` +func getTagSupertags(cte string) (result map[int64]*webTagSupertag, err error) { + rows, err := db.Query(cte + ` SELECT DISTINCT ta.tag, ts.name, t.name FROM tag_assignment AS ta JOIN matches AS m ON m.sha1 = ta.sha1 JOIN tag AS t ON ta.tag = t.id - JOIN tag_space AS ts ON ts.id = t.space`, tag) + JOIN tag_space AS ts ON ts.id = t.space`) if err != nil { return nil, err } @@ -1032,18 +1127,18 @@ type webTagRelated struct { Score float32 `json:"score"` } -func getTagRelated(tag int64, matches int) ( +func getTagRelated(cte string, matches int) ( result map[string][]webTagRelated, err error) { // Not sure if this level of efficiency is achievable directly in SQL. - supertags, err := getTagSupertags(tag) + supertags, err := getTagSupertags(cte) if err != nil { return nil, err } - rows, err := db.Query(searchCTE+` + rows, err := db.Query(cte + ` SELECT ta.tag, ta.weight FROM tag_assignment AS ta - JOIN matches AS m ON m.sha1 = ta.sha1`, tag) + JOIN matches AS m ON m.sha1 = ta.sha1`) if err != nil { return nil, err } @@ -1084,13 +1179,7 @@ func handleAPISearch(w http.ResponseWriter, r *http.Request) { Related map[string][]webTagRelated `json:"related"` } - space, tag, _ := strings.Cut(params.Query, ":") - - var tagID int64 - err := db.QueryRow(` - SELECT t.id FROM tag AS t - JOIN tag_space AS ts ON t.space = ts.id - WHERE ts.name = ? AND t.name = ?`, space, tag).Scan(&tagID) + cte, err := parseQuery(params.Query) if errors.Is(err, sql.ErrNoRows) { http.Error(w, err.Error(), http.StatusNotFound) return @@ -1099,11 +1188,11 @@ func handleAPISearch(w http.ResponseWriter, r *http.Request) { return } - if result.Matches, err = getTagMatches(tagID); err != nil { + if result.Matches, err = getTagMatches(cte); err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return } - if result.Related, err = getTagRelated(tagID, + if result.Related, err = getTagRelated(cte, len(result.Matches)); err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return diff --git a/public/gallery.js b/public/gallery.js index 9d3b067..01439f7 100644 --- a/public/gallery.js +++ b/public/gallery.js @@ -646,7 +646,11 @@ let Search = { m(Header), m('.body', {}, [ m('.sidebar', [ - m('p', SearchModel.query), + m('input', { + value: SearchModel.query, + onchange: event => m.route.set( + `/search/:key`, {key: event.target.value}), + }), m(SearchRelated), ]), m(SearchView), diff --git a/public/style.css b/public/style.css index 1bdeb3f..7fd0079 100644 --- a/public/style.css +++ b/public/style.css @@ -27,6 +27,8 @@ a { color: inherit; } .sidebar { padding: .25rem .5rem; background: var(--shade-color); border-right: 1px solid #ccc; overflow: auto; min-width: 10rem; max-width: 20rem; flex-shrink: 0; } +.sidebar input { width: 100%; box-sizing: border-box; margin: .5rem 0; + font-size: inherit; } .sidebar h2 { margin: 0.5em 0 0.25em 0; padding: 0; font-size: 1.2rem; } .sidebar ul { margin: .5rem 0; padding: 0; }