1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
16
17 package search
18
19 import (
20 "bytes"
21 "context"
22 "encoding/json"
23 "log"
24 "net/http"
25 "os"
26 "strconv"
27 "sync"
28 "time"
29
30 "github.com/gorilla/websocket"
31 "perkeep.org/pkg/schema"
32 )
33
34 const (
35
36 writeWait = 10 * time.Second
37
38
39 pongWait = 60 * time.Second
40
41
42 pingPeriod = (pongWait * 9) / 10
43
44
45 maxMessageSize = 10 << 10
46 )
47
// debug turns on extra logging (in this file: websocket message and
// subscription logging). It is read once from the CAMLI_DEBUG environment
// variable. A value strconv.ParseBool rejects, or an unset variable, leaves
// it false, so the parse error is deliberately ignored.
var debug = func() bool {
	enabled, _ := strconv.ParseBool(os.Getenv("CAMLI_DEBUG"))
	return enabled
}()
49
50 type wsHub struct {
51 sh *Handler
52 register chan *wsConn
53 unregister chan *wsConn
54 watchReq chan watchReq
55 newBlobRecv chan schema.CamliType
56 updatedResults chan *watchedQuery
57 statusUpdate chan json.RawMessage
58
59
60 conns map[*wsConn]bool
61 }
62
// newWebsocketHub returns a hub for sh with its connection set and all of its
// channels allocated. The hub processes nothing until its run method is
// started.
func newWebsocketHub(sh *Handler) *wsHub {
	hub := new(wsHub)
	hub.sh = sh
	hub.conns = make(map[*wsConn]bool)

	// Connection (un)registration is unbuffered: the sender waits for run.
	hub.register = make(chan *wsConn)
	hub.unregister = make(chan *wsConn)

	// Event channels use the package-level `buffered` capacity (defined
	// elsewhere in this package).
	hub.watchReq = make(chan watchReq, buffered)
	hub.newBlobRecv = make(chan schema.CamliType, buffered)
	hub.updatedResults = make(chan *watchedQuery, buffered)
	hub.statusUpdate = make(chan json.RawMessage, buffered)
	return hub
}
75
// run is the hub's event loop and never returns. After construction, h.conns
// and every wsConn.queries map are read and written only here. They therefore
// need no locking, provided run executes on a single goroutine (presumably
// started once by the Handler; confirm at the call site).
//
// Events handled:
//   - statusUpdate: wrap the raw status as {"tag":"_status","status":<st>},
//     remember it, and broadcast it to every connection.
//   - register: start tracking a connection and send it the last status, if
//     one has been received yet.
//   - unregister: stop tracking a connection and close its send channel.
//   - newBlobRecv: for a non-empty camliType, re-run every watched query.
//   - watchReq: add or replace (non-nil query) or remove (nil query) the
//     subscription for a connection and tag.
//   - updatedResults: push a query's latest results to its client, provided
//     the client is still connected and the query is still its current
//     subscription for that tag.
//
// Messages are queued to clients without blocking (see sendOrDrop). A client
// whose send buffer is full is disconnected, so it cannot stall the hub, and
// through it every other client and every sender on the hub's channels.
func (h *wsHub) run() {
	var lastStatusMsg []byte
	for {
		select {
		case st := <-h.statusUpdate:
			const prefix = `{"tag":"_status","status":`
			lastStatusMsg = make([]byte, 0, len(prefix)+len(st)+1)
			lastStatusMsg = append(lastStatusMsg, prefix...)
			lastStatusMsg = append(lastStatusMsg, st...)
			lastStatusMsg = append(lastStatusMsg, '}')
			// Deleting from h.conns while ranging (via sendOrDrop) is safe.
			for c := range h.conns {
				h.sendOrDrop(c, lastStatusMsg)
			}
		case c := <-h.register:
			h.conns[c] = true
			// Before the first status update there is nothing to send; a nil
			// message would reach the client as an empty, non-JSON frame.
			if lastStatusMsg != nil {
				h.sendOrDrop(c, lastStatusMsg)
			}
		case c := <-h.unregister:
			// c may already have been dropped by sendOrDrop. In that case its
			// send channel is closed and must not be closed a second time.
			if h.conns[c] {
				h.dropConn(c)
			}
		case camliType := <-h.newBlobRecv:
			if camliType == "" {
				// Blobs with an empty camliType never trigger re-searches.
				continue
			}
			for conn := range h.conns {
				for _, wq := range conn.queries {
					go h.redoSearch(wq)
				}
			}
		case wr := <-h.watchReq:
			if wr.q == nil {
				// A request without a query unsubscribes the tag.
				delete(wr.conn.queries, wr.tag)
				log.Printf("Removed subscription for %v, %q", wr.conn, wr.tag)
				continue
			}
			// Replacing any previous subscription under the same tag.
			wq := &watchedQuery{
				conn: wr.conn,
				tag:  wr.tag,
				q:    wr.q,
			}
			wr.conn.queries[wr.tag] = wq
			if debug {
				log.Printf("websocket: added/updated search subscription for tag %q", wr.tag)
			}
			go h.doSearch(wq)
		case wq := <-h.updatedResults:
			// Skip results for clients that went away. Also skip queries that
			// were removed, or replaced by a newer subscription under the same
			// tag: a stale search may finish after its replacement.
			if !h.conns[wq.conn] || wq.conn.queries[wq.tag] != wq {
				continue
			}
			wq.mu.Lock()
			lastres := wq.lastres
			wq.mu.Unlock()
			resb, err := json.Marshal(wsUpdateMessage{
				Tag:    wq.tag,
				Result: lastres,
			})
			if err != nil {
				// Don't take the whole server down over one unencodable result.
				log.Printf("websocket: encoding update for tag %q: %v", wq.tag, err)
				continue
			}
			h.sendOrDrop(wq.conn, resb)
		}
	}
}

// sendOrDrop queues msg on c's send channel without blocking. If the buffer
// is full, the client is considered stuck and is dropped. Must only be called
// from run, for a connection currently in h.conns.
func (h *wsHub) sendOrDrop(c *wsConn, msg []byte) {
	select {
	case c.send <- msg:
	default:
		log.Printf("websocket: send buffer full, dropping client %p", c)
		h.dropConn(c)
	}
}

// dropConn stops tracking c and closes its send channel. Closing the channel
// makes writePump send a close frame and shut the websocket down. Must only
// be called from run, at most once per connection.
func (h *wsHub) dropConn(c *wsConn) {
	delete(h.conns, c)
	close(c.send)
}
150
151
152
153
154
155 func (h *wsHub) redoSearch(wq *watchedQuery) {
156 wq.mu.Lock()
157 defer wq.mu.Unlock()
158 wq.dirty = true
159 if wq.refreshing {
160
161
162 return
163 }
164 for wq.dirty {
165 wq.refreshing = true
166 wq.dirty = false
167 wq.mu.Unlock()
168 h.doSearch(wq)
169 wq.mu.Lock()
170 }
171 wq.refreshing = false
172 }
173
// doSearch executes wq's query once. When the JSON encoding of the result
// differs from the previous run's, it hands wq to the hub via
// h.updatedResults so run can push the new result to the client. The result
// and its encoding are stored on wq (under wq.mu) either way.
//
// The query is copied, including a cloned Describe request, so that
// h.sh.Query operates on its own copy rather than on wq.q. Presumably Query
// may modify the request it is given; TODO confirm.
func (h *wsHub) doSearch(wq *watchedQuery) {
	q := new(SearchQuery)
	*q = *wq.q
	if q.Describe != nil {
		q.Describe = wq.q.Describe.Clone()
	}

	res, err := h.sh.Query(context.TODO(), q)
	if err != nil {
		log.Printf("Query error: %v", err)
		return
	}
	resj, err := json.Marshal(res)
	if err != nil {
		// Without an encoding we can neither detect changes nor deliver the
		// result (run encodes the same value to send it), so skip this round
		// instead of recording a bogus nil encoding.
		log.Printf("websocket: encoding search result for tag %q: %v", wq.tag, err)
		return
	}

	wq.mu.Lock()
	unchanged := bytes.Equal(wq.lastresj, resj)
	wq.lastres = res
	wq.lastresj = resj
	wq.mu.Unlock()
	if unchanged {
		// The client already has an identical result.
		return
	}
	h.updatedResults <- wq
}
200
201 type wsConn struct {
202 ws *websocket.Conn
203 send chan []byte
204 sh *Handler
205
206
207 queries map[string]*watchedQuery
208 }
209
210 type watchedQuery struct {
211 conn *wsConn
212 tag string
213 q *SearchQuery
214
215 mu sync.Mutex
216 refreshing bool
217 dirty bool
218 lastres *SearchResult
219 lastresj []byte
220 }
221
222
223 type watchReq struct {
224 conn *wsConn
225 tag string
226 q *SearchQuery
227 }
228
229
230 type wsClientMessage struct {
231
232 Tag string `json:"tag"`
233
234 Query *SearchQuery `json:"query,omitempty"`
235 }
236
237 type wsUpdateMessage struct {
238 Tag string `json:"tag"`
239 Result *SearchResult `json:"result,omitempty"`
240 }
241
242
243 func (c *wsConn) readPump() {
244 defer func() {
245 c.sh.wsHub.unregister <- c
246 c.ws.Close()
247 }()
248 c.ws.SetReadLimit(maxMessageSize)
249 c.ws.SetReadDeadline(time.Now().Add(pongWait))
250 c.ws.SetPongHandler(func(string) error { c.ws.SetReadDeadline(time.Now().Add(pongWait)); return nil })
251 for {
252 _, message, err := c.ws.ReadMessage()
253 if err != nil {
254 break
255 }
256 if debug {
257 log.Printf("websocket: got message %#q", message)
258 }
259 cm := new(wsClientMessage)
260 if err := json.Unmarshal(message, cm); err != nil {
261 log.Printf("Ignoring bogus websocket message. Err: %v", err)
262 continue
263 }
264 c.sh.wsHub.watchReq <- watchReq{
265 conn: c,
266 tag: cm.Tag,
267 q: cm.Query,
268 }
269 }
270 }
271
272
// write sends one message of type mt to the peer, allowing at most writeWait
// for the write to complete.
func (c *wsConn) write(mt int, payload []byte) error {
	deadline := time.Now().Add(writeWait)
	c.ws.SetWriteDeadline(deadline)
	return c.ws.WriteMessage(mt, payload)
}
277
278
279 func (c *wsConn) writePump() {
280 ticker := time.NewTicker(pingPeriod)
281 defer func() {
282 ticker.Stop()
283 c.ws.Close()
284 }()
285 for {
286 select {
287 case message, ok := <-c.send:
288 if !ok {
289 c.write(websocket.CloseMessage, []byte{})
290 return
291 }
292 if err := c.write(websocket.TextMessage, message); err != nil {
293 return
294 }
295 case <-ticker.C:
296 if err := c.write(websocket.PingMessage, []byte{}); err != nil {
297 return
298 }
299 }
300 }
301 }
302
303
304 var upgrader = websocket.Upgrader{
305 ReadBufferSize: 1024,
306 WriteBufferSize: 1024,
307
308 }
309
// serveWebSocket upgrades req to a websocket connection and registers it with
// sh.wsHub. It then serves the connection: writes run on a new goroutine
// (writePump), while reads run on the calling goroutine (readPump) until the
// connection ends.
func (sh *Handler) serveWebSocket(rw http.ResponseWriter, req *http.Request) {
	ws, err := upgrader.Upgrade(rw, req, nil)
	if err != nil {
		// When the handshake fails, Upgrade has already replied to the client
		// with an HTTP error. Calling http.Error here would write a second,
		// superfluous response. Only unexpected errors are worth logging.
		if _, isHandshake := err.(websocket.HandshakeError); !isHandshake {
			log.Println(err)
		} else if debug {
			log.Printf("websocket: handshake failed: %v", err)
		}
		return
	}
	c := &wsConn{
		ws:      ws,
		send:    make(chan []byte, 256),
		sh:      sh,
		queries: make(map[string]*watchedQuery),
	}
	sh.wsHub.register <- c
	go c.writePump()
	c.readPump()
}