1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
16
17 package search
18
19 import (
20 "bytes"
21 "context"
22 "encoding/json"
23 "errors"
24 "log"
25 "net/http"
26 "os"
27 "strconv"
28 "sync"
29 "time"
30
31 "github.com/gorilla/websocket"
32
33 "perkeep.org/pkg/schema"
34 )
35
const (
	// writeWait is the time allowed to write a single message to the peer;
	// it is applied as the write deadline before every write.
	writeWait = 10 * time.Second

	// pongWait is the time allowed to read the next message (including a
	// pong) from the peer; the read deadline is pushed out by this much
	// each time a pong arrives.
	pongWait = 60 * time.Second

	// pingPeriod is the interval between pings sent to the peer. It is
	// derived as 90% of pongWait so a ping always goes out before the read
	// deadline can expire.
	pingPeriod = (pongWait * 9) / 10

	// maxMessageSize is the read limit, in bytes, set on the connection for
	// messages arriving from the peer (10 KiB).
	maxMessageSize = 10 << 10
)
49
// debug enables verbose websocket logging. It is true when the CAMLI_DEBUG
// environment variable parses as a true boolean; the parse error is
// deliberately ignored, so an unset or malformed value leaves debug false.
var debug, _ = strconv.ParseBool(os.Getenv("CAMLI_DEBUG"))
51
// wsHub tracks the registered websocket connections and their search
// subscriptions. Its channels are serviced by the run loop.
type wsHub struct {
	sh             *Handler              // search handler used to execute queries
	register       chan *wsConn          // connections to add to conns
	unregister     chan *wsConn          // connections to remove from conns
	watchReq       chan watchReq         // subscription add/update/remove requests
	newBlobRecv    chan schema.CamliType // type of a blob event; a non-empty value re-runs every watched query
	updatedResults chan *watchedQuery    // queries whose latest result differs from the previous one
	statusUpdate   chan json.RawMessage  // raw status JSON to broadcast to every connection

	// conns is the set of registered connections. In the code visible here
	// it is only read and written by run.
	conns map[*wsConn]bool
}
64
65 func newWebsocketHub(sh *Handler) *wsHub {
66 return &wsHub{
67 sh: sh,
68 register: make(chan *wsConn),
69 unregister: make(chan *wsConn),
70 conns: make(map[*wsConn]bool),
71 watchReq: make(chan watchReq, buffered),
72 newBlobRecv: make(chan schema.CamliType, buffered),
73 updatedResults: make(chan *watchedQuery, buffered),
74 statusUpdate: make(chan json.RawMessage, buffered),
75 }
76 }
77
78 func (h *wsHub) run() {
79 var lastStatusMsg []byte
80 for {
81 select {
82 case st := <-h.statusUpdate:
83 const prefix = `{"tag":"_status","status":`
84 lastStatusMsg = make([]byte, 0, len(prefix)+len(st)+1)
85 lastStatusMsg = append(lastStatusMsg, prefix...)
86 lastStatusMsg = append(lastStatusMsg, st...)
87 lastStatusMsg = append(lastStatusMsg, '}')
88 for c := range h.conns {
89 c.send <- lastStatusMsg
90 }
91 case c := <-h.register:
92 h.conns[c] = true
93 c.send <- lastStatusMsg
94 case c := <-h.unregister:
95 delete(h.conns, c)
96 close(c.send)
97 case camliType := <-h.newBlobRecv:
98 if camliType == "" {
99
100
101
102
103
104
105 continue
106 }
107
108 for conn := range h.conns {
109 for _, wq := range conn.queries {
110 go h.redoSearch(wq)
111 }
112 }
113 case wr := <-h.watchReq:
114
115 if wr.q == nil {
116 delete(wr.conn.queries, wr.tag)
117 log.Printf("Removed subscription for %v, %q", wr.conn, wr.tag)
118 continue
119 }
120
121
122 wq := &watchedQuery{
123 conn: wr.conn,
124 tag: wr.tag,
125 q: wr.q,
126 }
127 wr.conn.queries[wr.tag] = wq
128 if debug {
129 log.Printf("websocket: added/updated search subscription for tag %q", wr.tag)
130 }
131 go h.doSearch(wq)
132
133 case wq := <-h.updatedResults:
134 if !h.conns[wq.conn] || wq.conn.queries[wq.tag] == nil {
135
136 continue
137 }
138 wq.mu.Lock()
139 lastres := wq.lastres
140 wq.mu.Unlock()
141 resb, err := json.Marshal(wsUpdateMessage{
142 Tag: wq.tag,
143 Result: lastres,
144 })
145 if err != nil {
146 panic(err)
147 }
148 wq.conn.send <- resb
149 }
150 }
151 }
152
153
154
155
156
157 func (h *wsHub) redoSearch(wq *watchedQuery) {
158 wq.mu.Lock()
159 defer wq.mu.Unlock()
160 wq.dirty = true
161 if wq.refreshing {
162
163
164 return
165 }
166 for wq.dirty {
167 wq.refreshing = true
168 wq.dirty = false
169 wq.mu.Unlock()
170 h.doSearch(wq)
171 wq.mu.Lock()
172 }
173 wq.refreshing = false
174 }
175
176 func (h *wsHub) doSearch(wq *watchedQuery) {
177
178 q := new(SearchQuery)
179 *q = *wq.q
180 if q.Describe != nil {
181 q.Describe = wq.q.Describe.Clone()
182 }
183
184 res, err := h.sh.Query(context.TODO(), q)
185 if err != nil {
186 log.Printf("Query error: %v", err)
187 return
188 }
189 resj, _ := json.Marshal(res)
190
191 wq.mu.Lock()
192 eq := bytes.Equal(wq.lastresj, resj)
193 wq.lastres = res
194 wq.lastresj = resj
195 wq.mu.Unlock()
196 if eq {
197
198 return
199 }
200 h.updatedResults <- wq
201 }
202
// wsConn is one client websocket connection together with its outbound
// message queue and its search subscriptions.
type wsConn struct {
	ws   *websocket.Conn // underlying websocket connection
	send chan []byte     // outbound text messages, drained by writePump
	sh   *Handler        // search handler that owns the hub this connection registers with

	// queries maps a subscription tag to its watched query. In the code
	// visible here it is only read and written by the hub's run loop.
	queries map[string]*watchedQuery
}
211
// watchedQuery is a search subscription: a query registered by a connection
// under a tag, plus the state of its most recent execution.
type watchedQuery struct {
	conn *wsConn      // connection that owns the subscription
	tag  string       // client-chosen identifier, echoed back in updates
	q    *SearchQuery // the subscribed query

	mu         sync.Mutex    // guards the fields below
	refreshing bool          // a redoSearch call is currently re-running the query
	dirty      bool          // another refresh was requested; the refresher must search again
	lastres    *SearchResult // most recent search result
	lastresj   []byte        // JSON encoding of lastres, compared to detect unchanged results
}
223
224
// watchReq is a request, sent from a connection's readPump to the hub, to
// add, replace or remove the subscription identified by tag on conn.
type watchReq struct {
	conn *wsConn
	tag  string
	q    *SearchQuery // if nil, the subscription for tag is removed
}
230
231
// wsClientMessage is the JSON message a client sends over the websocket to
// manage its subscriptions.
type wsClientMessage struct {
	// Tag identifies the subscription the message refers to.
	Tag string `json:"tag"`
	// Query is the search to watch under Tag. When it is absent, the hub
	// removes the subscription for Tag.
	Query *SearchQuery `json:"query,omitempty"`
}
238
// wsUpdateMessage is the JSON message the hub sends to a client when the
// result of the subscription identified by Tag has changed.
type wsUpdateMessage struct {
	Tag    string        `json:"tag"`
	Result *SearchResult `json:"result,omitempty"`
}
243
244
245 func (c *wsConn) readPump() {
246 defer func() {
247 c.sh.wsHub.unregister <- c
248 c.ws.Close()
249 }()
250 c.ws.SetReadLimit(maxMessageSize)
251 c.ws.SetReadDeadline(time.Now().Add(pongWait))
252 c.ws.SetPongHandler(func(string) error { c.ws.SetReadDeadline(time.Now().Add(pongWait)); return nil })
253 for {
254 _, message, err := c.ws.ReadMessage()
255 if err != nil {
256 break
257 }
258 if debug {
259 log.Printf("websocket: got message %#q", message)
260 }
261 cm := new(wsClientMessage)
262 if err := json.Unmarshal(message, cm); err != nil {
263 log.Printf("Ignoring bogus websocket message. Err: %v", err)
264 continue
265 }
266 c.sh.wsHub.watchReq <- watchReq{
267 conn: c,
268 tag: cm.Tag,
269 q: cm.Query,
270 }
271 }
272 }
273
274
275 func (c *wsConn) write(mt int, payload []byte) error {
276 c.ws.SetWriteDeadline(time.Now().Add(writeWait))
277 return c.ws.WriteMessage(mt, payload)
278 }
279
280
281 func (c *wsConn) writePump() {
282 ticker := time.NewTicker(pingPeriod)
283 defer func() {
284 ticker.Stop()
285 c.ws.Close()
286 }()
287 for {
288 select {
289 case message, ok := <-c.send:
290 if !ok {
291 c.write(websocket.CloseMessage, []byte{})
292 return
293 }
294 if err := c.write(websocket.TextMessage, message); err != nil {
295 return
296 }
297 case <-ticker.C:
298 if err := c.write(websocket.PingMessage, []byte{}); err != nil {
299 return
300 }
301 }
302 }
303 }
304
305
// upgrader upgrades incoming HTTP requests to websocket connections using
// 1024-byte read and write buffers. No CheckOrigin or Error function is set,
// so the gorilla/websocket defaults apply for origin checking and for the
// HTTP error response written on a failed handshake.
var upgrader = websocket.Upgrader{
	ReadBufferSize:  1024,
	WriteBufferSize: 1024,
}
311
312 func (h *Handler) serveWebSocket(rw http.ResponseWriter, req *http.Request) {
313 ws, err := upgrader.Upgrade(rw, req, nil)
314 var he websocket.HandshakeError
315 if errors.As(err, &he) {
316 http.Error(rw, "Not a websocket handshake", http.StatusBadRequest)
317 return
318 } else if err != nil {
319 log.Println(err)
320 return
321 }
322 c := &wsConn{
323 ws: ws,
324 send: make(chan []byte, 256),
325 sh: h,
326 queries: make(map[string]*watchedQuery),
327 }
328 h.wsHub.register <- c
329 go c.writePump()
330 c.readPump()
331 }