diff --git a/pkg/s3select/message.go b/pkg/s3select/message.go
index 9b4b202f4..60ce9238e 100644
--- a/pkg/s3select/message.go
+++ b/pkg/s3select/message.go
@@ -226,12 +226,12 @@ type messageWriter struct {
 
 	payloadBuffer      []byte
 	payloadBufferIndex int
+	payloadCh          chan []byte
 
-	dataCh  chan []byte
-	doneCh  chan struct{}
-	closeCh chan struct{}
-	stopped uint32
-	closed  uint32
+	finBytesScanned, finBytesProcessed int64
+
+	errCh  chan []byte
+	doneCh chan struct{}
 }
 
 func (writer *messageWriter) write(data []byte) bool {
@@ -246,80 +246,89 @@ func (writer *messageWriter) write(data []byte) bool {
 func (writer *messageWriter) start() {
 	keepAliveTicker := time.NewTicker(1 * time.Second)
 	var progressTicker *time.Ticker
+	var progressTickerC <-chan time.Time
 	if writer.getProgressFunc != nil {
 		progressTicker = time.NewTicker(1 * time.Minute)
+		progressTickerC = progressTicker.C
 	}
+	recordStagingTicker := time.NewTicker(500 * time.Millisecond)
+
+	// Exit conditions:
+	//
+	// 1. If a writer.write() returns false, the select loop below exits
+	// and `doneCh` is closed to tell the caller to stop as well.
+	//
+	// 2. If the caller (Evaluate()) hits an error, it sends an error
+	// message on `errCh` and waits for this go-routine to quit in
+	// FinishWithError().
+	//
+	// 3. If the caller is done, it closes `payloadCh` and waits for
+	// this go-routine to exit in Finish().
+
+	quitFlag := false
+	for !quitFlag {
+		select {
+		case data := <-writer.errCh:
+			quitFlag = true
+			// Flush collected records before sending the error message.
+			if !writer.flushRecords() {
+				break
+			}
+			writer.write(data)
 
-	go func() {
-		quitFlag := 0
-		for quitFlag == 0 {
-			if progressTicker == nil {
-				select {
-				case data, ok := <-writer.dataCh:
-					if !ok {
-						quitFlag = 1
-						break
-					}
-					if !writer.write(data) {
-						quitFlag = 1
-					}
-				case <-writer.doneCh:
-					quitFlag = 2
-				case <-keepAliveTicker.C:
-					if !writer.write(continuationMessage) {
-						quitFlag = 1
-					}
+		case payload, ok := <-writer.payloadCh:
+			if !ok {
+				// payloadCh is closed by the caller (in Finish) to
+				// indicate finish with success.
+				quitFlag = true
+
+				if !writer.flushRecords() {
+					break
+				}
+				// Write the Stats message, then the End message.
+				bytesReturned := atomic.LoadInt64(&writer.bytesReturned)
+				if !writer.write(newStatsMessage(writer.finBytesScanned, writer.finBytesProcessed, bytesReturned)) {
+					break
 				}
+				writer.write(endMessage)
 			} else {
-				select {
-				case data, ok := <-writer.dataCh:
-					if !ok {
-						quitFlag = 1
+				// Stage the record; relies on len(payload) <= bufLength (Evaluate checks maxRecordSize).
+				freeSpace := bufLength - writer.payloadBufferIndex
+				if len(payload) > freeSpace {
+					if !writer.flushRecords() {
+						quitFlag = true
 						break
 					}
-					if !writer.write(data) {
-						quitFlag = 1
-					}
-				case <-writer.doneCh:
-					quitFlag = 2
-				case <-keepAliveTicker.C:
-					if !writer.write(continuationMessage) {
-						quitFlag = 1
-					}
-				case <-progressTicker.C:
-					bytesScanned, bytesProcessed := writer.getProgressFunc()
-					bytesReturned := atomic.LoadInt64(&writer.bytesReturned)
-					if !writer.write(newProgressMessage(bytesScanned, bytesProcessed, bytesReturned)) {
-						quitFlag = 1
-					}
 				}
+				copy(writer.payloadBuffer[writer.payloadBufferIndex:], payload)
+				writer.payloadBufferIndex += len(payload)
 			}
-		}
 
-		atomic.StoreUint32(&writer.stopped, 1)
-		close(writer.closeCh)
+		case <-recordStagingTicker.C:
+			if !writer.flushRecords() {
+				quitFlag = true
+				break
+			}
 
-		keepAliveTicker.Stop()
-		if progressTicker != nil {
-			progressTicker.Stop()
-		}
+		case <-keepAliveTicker.C:
+			if !writer.write(continuationMessage) {
+				quitFlag = true
+			}
 
-		if quitFlag == 2 {
-			for data := range writer.dataCh {
-				if _, err := writer.writer.Write(data); err != nil {
-					break
-				}
+		case <-progressTickerC:
+			bytesScanned, bytesProcessed := writer.getProgressFunc()
+			bytesReturned := atomic.LoadInt64(&writer.bytesReturned)
+			if !writer.write(newProgressMessage(bytesScanned, bytesProcessed, bytesReturned)) {
+				quitFlag = true
 			}
 		}
-	}()
-}
+	}
+	close(writer.doneCh)
 
-func (writer *messageWriter) close() {
-	if atomic.SwapUint32(&writer.closed, 1) == 0 {
-		close(writer.doneCh)
-		for range writer.closeCh {
-			close(writer.dataCh)
-		}
+	recordStagingTicker.Stop()
+	keepAliveTicker.Stop()
+	if progressTicker != nil {
+		progressTicker.Stop()
 	}
 }
 
@@ -327,88 +336,69 @@ const (
 	bufLength = maxRecordSize
 )
 
-// collectRecord - collects records into a buffer, and when it is
-// full, sends a message with the collected payload.
-func (writer *messageWriter) collectRecord(data []byte) (err error) {
-	freeSpace := bufLength - writer.payloadBufferIndex
-	if len(data) > freeSpace {
-		err = writer.FlushRecords()
-		if err != nil {
-			return err
-		}
-	}
-	copy(writer.payloadBuffer[writer.payloadBufferIndex:], data)
-	writer.payloadBufferIndex += len(data)
-	return nil
-}
-
-func (writer *messageWriter) send(data []byte) error {
-	err := func() error {
-		if atomic.LoadUint32(&writer.stopped) == 1 {
-			return fmt.Errorf("writer already closed")
-		}
-
-		select {
-		case writer.dataCh <- data:
-		case <-writer.doneCh:
-			return fmt.Errorf("closed writer")
-		}
-
+// SendRecord hands one whole record to the start() go-routine via payloadCh; it errors once doneCh is closed.
+func (writer *messageWriter) SendRecord(payload []byte) error {
+	select {
+	case writer.payloadCh <- payload:
 		return nil
-	}()
-
-	if err != nil {
-		writer.close()
+	case <-writer.doneCh:
+		return fmt.Errorf("messageWriter is done")
 	}
-
-	return err
-}
-
-func (writer *messageWriter) SendRecords(payload []byte) error {
-	return writer.collectRecord(payload)
 }
 
-func (writer *messageWriter) FlushRecords() (err error) {
-	err = writer.send(newRecordsMessage(writer.payloadBuffer[0:writer.payloadBufferIndex]))
-	if err != nil {
-		return err
+func (writer *messageWriter) flushRecords() bool {
+	if writer.payloadBufferIndex == 0 {
+		return true
+	}
+	result := writer.write(newRecordsMessage(writer.payloadBuffer[0:writer.payloadBufferIndex]))
+	if result {
+		atomic.AddInt64(&writer.bytesReturned, int64(writer.payloadBufferIndex))
+		writer.payloadBufferIndex = 0
 	}
-	atomic.AddInt64(&writer.bytesReturned, int64(writer.payloadBufferIndex))
-	writer.payloadBufferIndex = 0
-	return nil
+	return result
 }
 
-func (writer *messageWriter) SendStats(bytesScanned, bytesProcessed int64) error {
-	bytesReturned := atomic.LoadInt64(&writer.bytesReturned)
-	err := writer.send(newStatsMessage(bytesScanned, bytesProcessed, bytesReturned))
-	if err != nil {
-		return err
+// Finish is the last call to the message writer - it closes payloadCh so
+// that start() flushes any remaining record payload, then sends statistics
+// and finally the end message. It returns an error if doneCh is already closed.
+func (writer *messageWriter) Finish(bytesScanned, bytesProcessed int64) error {
+	select {
+	case <-writer.doneCh:
+		return fmt.Errorf("messageWriter is done")
+	default:
+		writer.finBytesScanned = bytesScanned
+		writer.finBytesProcessed = bytesProcessed
+		close(writer.payloadCh)
+		// Wait until the `start` go-routine is done.
+		<-writer.doneCh
+		return nil
 	}
-
-	err = writer.send(endMessage)
-	writer.close()
-	return err
 }
 
-func (writer *messageWriter) SendError(errorCode, errorMessage string) error {
-	err := writer.send(newErrorMessage([]byte(errorCode), []byte(errorMessage)))
-	if err == nil {
-		writer.close()
+func (writer *messageWriter) FinishWithError(errorCode, errorMessage string) error {
+	select {
+	case <-writer.doneCh:
+		return fmt.Errorf("messageWriter is done")
+	case writer.errCh <- newErrorMessage([]byte(errorCode), []byte(errorMessage)):
+		// Wait until the `start` go-routine is done.
+		<-writer.doneCh
+		return nil
 	}
-	return err
 }
 
+// newMessageWriter creates a message writer that writes to the HTTP
+// response writer and launches its start() go-routine.
 func newMessageWriter(w http.ResponseWriter, getProgressFunc func() (bytesScanned, bytesProcessed int64)) *messageWriter {
 	writer := &messageWriter{
 		writer:          w,
 		getProgressFunc: getProgressFunc,
 
 		payloadBuffer: make([]byte, bufLength),
+		payloadCh:     make(chan []byte),
 
-		dataCh:  make(chan []byte),
-		doneCh:  make(chan struct{}),
-		closeCh: make(chan struct{}),
+		errCh:  make(chan []byte),
+		doneCh: make(chan struct{}),
 	}
-	writer.start()
+	go writer.start()
 	return writer
 }
diff --git a/pkg/s3select/select.go b/pkg/s3select/select.go
index b378b4485..8f6b46c04 100644
--- a/pkg/s3select/select.go
+++ b/pkg/s3select/select.go
@@ -329,11 +329,11 @@ func (s3Select *S3Select) Evaluate(w http.ResponseWriter) {
 		}
 
 		if len(data) > maxRecordSize {
-			writer.SendError("OverMaxRecordSize", "The length of a record in the input or result is greater than maxCharsPerRecord of 1 MB.")
+			writer.FinishWithError("OverMaxRecordSize", "The length of a record in the input or result is greater than maxCharsPerRecord of 1 MB.")
 			return false
 		}
 
-		if err = writer.SendRecords(data); err != nil {
+		if err = writer.SendRecord(data); err != nil {
 			// FIXME: log this error.
 			err = nil
 			return false
@@ -344,13 +344,7 @@ func (s3Select *S3Select) Evaluate(w http.ResponseWriter) {
 
 	for {
 		if s3Select.statement.LimitReached() {
-			if err = writer.FlushRecords(); err != nil {
-				// FIXME: log this error
-				err = nil
-				break
-			}
-
-			if err = writer.SendStats(s3Select.getProgress()); err != nil {
+			if err = writer.Finish(s3Select.getProgress()); err != nil {
 				// FIXME: log this error.
 				err = nil
 			}
@@ -373,17 +367,10 @@ func (s3Select *S3Select) Evaluate(w http.ResponseWriter) {
 				}
 			}
 
-			if err = writer.FlushRecords(); err != nil {
-				// FIXME: log this error
-				err = nil
-				break
-			}
-
-			if err = writer.SendStats(s3Select.getProgress()); err != nil {
+			if err = writer.Finish(s3Select.getProgress()); err != nil {
 				// FIXME: log this error.
 				err = nil
 			}
-
 			break
 		}
 
@@ -404,8 +391,7 @@ func (s3Select *S3Select) Evaluate(w http.ResponseWriter) {
 	}
 
 	if err != nil {
-		fmt.Printf("SQL Err: %#v\n", err)
-		if serr := writer.SendError("InternalError", err.Error()); serr != nil {
+		if serr := writer.FinishWithError("InternalError", err.Error()); serr != nil {
 			// FIXME: log errors.
 		}
 	}