diff --git a/mint/Dockerfile.dev b/mint/Dockerfile.dev index fd98d0cd9..9e4c30326 100644 --- a/mint/Dockerfile.dev +++ b/mint/Dockerfile.dev @@ -74,6 +74,9 @@ RUN build/worm/install.sh COPY build/healthcheck /mint/build/healthcheck RUN build/healthcheck/install.sh +COPY build/s3select /mint/build/s3select +RUN build/s3select/install.sh + COPY remove-packages.list /mint COPY postinstall.sh /mint RUN /mint/postinstall.sh diff --git a/mint/build/s3select/install.sh b/mint/build/s3select/install.sh new file mode 100755 index 000000000..aae2af2f6 --- /dev/null +++ b/mint/build/s3select/install.sh @@ -0,0 +1,18 @@ +#!/bin/bash -e +# +# Mint (C) 2020 Minio, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +python -m pip install minio diff --git a/mint/run/core/s3select/README.md b/mint/run/core/s3select/README.md new file mode 100644 index 000000000..73a69a22b --- /dev/null +++ b/mint/run/core/s3select/README.md @@ -0,0 +1,21 @@ + +## `s3select` tests +This directory serves as the location for Mint tests for s3select features. Top level `mint.sh` calls `run.sh` to execute tests. + +## Adding new tests +New tests are added into `s3select/tests.py` as new functions. + +## Running tests manually +- Set environment variables `MINT_DATA_DIR`, `MINT_MODE`, `SERVER_ENDPOINT`, `ACCESS_KEY`, `SECRET_KEY`, `SERVER_REGION` and `ENABLE_HTTPS` +- Call `run.sh` with output log file and error log file. for example + +```bash +export MINT_DATA_DIR=~/my-mint-dir +export MINT_MODE=core +export SERVER_ENDPOINT="play.min.io" +export ACCESS_KEY="Q3AM3UQ867SPQQA43P2F" +export SECRET_KEY="zuf+tfteSlswRu7BJ86wekitnifILbZam1KYY3TG" +export ENABLE_HTTPS=1 +export SERVER_REGION=us-east-1 +./run.sh /tmp/output.log /tmp/error.log +``` diff --git a/mint/run/core/s3select/run.sh b/mint/run/core/s3select/run.sh new file mode 100755 index 000000000..46db553e9 --- /dev/null +++ b/mint/run/core/s3select/run.sh @@ -0,0 +1,28 @@ +#!/bin/bash +# +# Mint (C) 2020 Minio, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# handle command line arguments +if [ $# -ne 2 ]; then + echo "usage: run.sh " + exit -1 +fi + +output_log_file="$1" +error_log_file="$2" + +# run path style tests +python "./tests.py" 1>>"$output_log_file" 2>"$error_log_file" diff --git a/mint/run/core/s3select/tests.py b/mint/run/core/s3select/tests.py new file mode 100644 index 000000000..145598a7f --- /dev/null +++ b/mint/run/core/s3select/tests.py @@ -0,0 +1,304 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# MinIO Python Library for Amazon S3 Compatible Cloud Storage, +# (C) 2015-2020 MinIO, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# from __future__ import division +# from __future__ import absolute_import + +import os +import io +from sys import exit +import uuid +import inspect +import json +import time +import traceback + +from minio import Minio +from minio.select.options import (SelectObjectOptions, CSVInput, + RequestProgress, InputSerialization, + OutputSerialization, CSVOutput, JsonOutput) + +class LogOutput(object): + """ + LogOutput is the class for log output. It is required standard for all + SDK tests controlled by mint. + Here are its attributes: + 'name': name of the SDK under test, e.g. 's3select' + 'function': name of the method/api under test with its signature + The following python code can be used to + pull args information of a and to + put together with the method name: + .__name__+'('+', '.join(args_list)+')' + e.g. 'remove_object(bucket_name, object_name)' + 'args': method/api arguments with their values, in + dictionary form: {'arg1': val1, 'arg2': val2, ...} + 'duration': duration of the whole test in milliseconds, + defaults to 0 + 'alert': any extra information user is needed to be alerted about, + like whether this is a Blocker/Gateway/Server related + issue, etc., defaults to None + 'message': descriptive error message, defaults to None + 'error': stack-trace/exception message(only in case of failure), + actual low level exception/error thrown by the program, + defaults to None + 'status': exit status, possible values are 'PASS', 'FAIL', 'NA', + defaults to 'PASS' + """ + + PASS = 'PASS' + FAIL = 'FAIL' + NA = 'NA' + + def __init__(self, meth, test_name): + self.__args_list = inspect.getargspec(meth).args[1:] + self.__name = 'minio-py:'+test_name + self.__function = meth.__name__+'('+', '.join(self.__args_list)+')' + self.__args = {} + self.__duration = 0 + self.__alert = '' + self.__message = None + self.__error = None + self.__status = self.PASS + self.__start_time = time.time() + + @property + def name(self): return self.__name + + @property + def function(self): return self.__function + + @property + def args(self): return self.__args + + @name.setter + def name(self, val): self.__name = val + + @function.setter + def function(self, val): self.__function = val + + @args.setter + def args(self, val): self.__args = val + + def json_report(self, err_msg='', alert='', status=''): + self.__args = {k: v for k, v in self.__args.items() if v and v != ''} + entry = {'name': self.__name, + 'function': self.__function, + 'args': self.__args, + 'duration': int(round((time.time() - self.__start_time)*1000)), + 'alert': str(alert), + 'message': str(err_msg), + 'error': traceback.format_exc() if err_msg and err_msg != '' else '', + 'status': status if status and status != '' else + self.FAIL if err_msg and err_msg != '' else self.PASS + } + return json.dumps({k: v for k, v in entry.items() if v and v != ''}) + +def generate_bucket_name(): + return "s3select-test-" + uuid.uuid4().__str__() + + +def test_csv_input_quote_char(client, log_output): + # Get a unique bucket_name and object_name + log_output.args['bucket_name'] = bucket_name = generate_bucket_name() + + tests = [ + # Invalid quote character, should fail + ('""', b'col1,col2,col3\n', Exception()), + # UTF-8 quote character + ('ع', b'\xd8\xb9col1\xd8\xb9,\xd8\xb9col2\xd8\xb9,\xd8\xb9col3\xd8\xb9\n', b'{"_1":"col1","_2":"col2","_3":"col3"}\n'), + # Only one field is quoted + ('"', b'"col1",col2,col3\n', b'{"_1":"col1","_2":"col2","_3":"col3"}\n'), + ('"', b'"col1,col2,col3"\n', b'{"_1":"col1,col2,col3"}\n'), + ('\'', b'"col1",col2,col3\n', b'{"_1":"\\"col1\\"","_2":"col2","_3":"col3"}\n'), + ('', b'"col1",col2,col3\n', b'{"_1":"\\"col1\\"","_2":"col2","_3":"col3"}\n'), + ('', b'"col1",col2,col3\n', b'{"_1":"\\"col1\\"","_2":"col2","_3":"col3"}\n'), + ('', b'"col1","col2","col3"\n', b'{"_1":"\\"col1\\"","_2":"\\"col2\\"","_3":"\\"col3\\""}\n'), + ('"', b'""""""\n', b'{"_1":"\\"\\""}\n'), + ] + + try: + client.make_bucket(bucket_name) + + for idx, (quote_char, object_content, expected_output) in enumerate(tests): + options = SelectObjectOptions( + expression="select * from s3object", + input_serialization=InputSerialization( + compression_type="NONE", + csv=CSVInput(FileHeaderInfo="NONE", + RecordDelimiter="\n", + FieldDelimiter=",", + QuoteCharacter=quote_char, + QuoteEscapeCharacter=quote_char, + Comments="#", + AllowQuotedRecordDelimiter="FALSE",), + ), + output_serialization=OutputSerialization( + json = JsonOutput( + RecordDelimiter="\n", + ) + ), + request_progress=RequestProgress( + enabled="False" + ) + ) + + got_output = b'' + + try: + got_output = exec_select(client, bucket_name, object_content, options, log_output) + except Exception as select_err: + if not isinstance(expected_output, Exception): + raise ValueError('Test {} unexpectedly failed with: {}'.format(idx+1, select_err)) + else: + if isinstance(expected_output, Exception): + raise ValueError('Test {}: expected an exception, got {}'.format(idx+1, got_output)) + if got_output != expected_output: + raise ValueError('Test {}: data mismatch. Expected : {}, Received {}'.format(idx+1, expected_output, got_output)) + + except Exception as err: + raise Exception(err) + finally: + try: + client.remove_bucket(bucket_name) + except Exception as err: + raise Exception(err) + + # Test passes + print(log_output.json_report()) + +def test_csv_output_quote_char(client, log_output): + # Get a unique bucket_name and object_name + log_output.args['bucket_name'] = bucket_name = generate_bucket_name() + + tests = [ + # UTF-8 quote character + ("''", b'col1,col2,col3\n', Exception()), + ("'", b'col1,col2,col3\n', b"'col1','col2','col3'\n"), + ("", b'col1,col2,col3\n', b'\x00col1\x00,\x00col2\x00,\x00col3\x00\n'), + ('"', b'col1,col2,col3\n', b'"col1","col2","col3"\n'), + ('"', b'col"1,col2,col3\n', b'"col""1","col2","col3"\n'), + ('"', b'\n', b''), + ] + + try: + client.make_bucket(bucket_name) + + for idx, (quote_char, object_content, expected_output) in enumerate(tests): + options = SelectObjectOptions( + expression="select * from s3object", + input_serialization=InputSerialization( + compression_type="NONE", + csv=CSVInput(FileHeaderInfo="NONE", + RecordDelimiter="\n", + FieldDelimiter=",", + QuoteCharacter='"', + QuoteEscapeCharacter='"', + Comments="#", + AllowQuotedRecordDelimiter="FALSE",), + ), + output_serialization=OutputSerialization( + csv=CSVOutput(QuoteFields="ALWAYS", + RecordDelimiter="\n", + FieldDelimiter=",", + QuoteCharacter=quote_char, + QuoteEscapeCharacter=quote_char,) + ), + request_progress=RequestProgress( + enabled="False" + ) + ) + + got_output = b'' + + try: + got_output = exec_select(client, bucket_name, object_content, options, log_output) + except Exception as select_err: + if not isinstance(expected_output, Exception): + raise ValueError('Test {} unexpectedly failed with: {}'.format(idx+1, select_err)) + else: + if isinstance(expected_output, Exception): + raise ValueError('Test {}: expected an exception, got {}'.format(idx+1, got_output)) + if got_output != expected_output: + raise ValueError('Test {}: data mismatch. Expected : {}. Received: {}.'.format(idx+1, expected_output, got_output)) + + except Exception as err: + raise Exception(err) + finally: + try: + client.remove_bucket(bucket_name) + except Exception as err: + raise Exception(err) + + # Test passes + print(log_output.json_report()) + + +def exec_select(client, bucket_name, object_content, options, log_output): + log_output.args['object_name'] = object_name = uuid.uuid4().__str__() + try: + bytes_content = io.BytesIO(object_content) + client.put_object(bucket_name, object_name, io.BytesIO(object_content), len(object_content)) + + data = client.select_object_content(bucket_name, object_name, options) + # Get the records + records = io.BytesIO() + for d in data.stream(10*1024): + records.write(d.encode('utf-8')) + + return records.getvalue() + + except Exception as err: + raise Exception(err) + finally: + try: + client.remove_object(bucket_name, object_name) + except Exception as err: + raise Exception(err) + + +def main(): + """ + Functional testing for S3 select. + """ + + try: + access_key = os.getenv('ACCESS_KEY', 'Q3AM3UQ867SPQQA43P2F') + secret_key = os.getenv('SECRET_KEY', + 'zuf+tfteSlswRu7BJ86wekitnifILbZam1KYY3TG') + server_endpoint = os.getenv('SERVER_ENDPOINT', 'play.min.io') + secure = os.getenv('ENABLE_HTTPS', '1') == '1' + if server_endpoint == 'play.min.io': + access_key = 'Q3AM3UQ867SPQQA43P2F' + secret_key = 'zuf+tfteSlswRu7BJ86wekitnifILbZam1KYY3TG' + secure = True + + client = Minio(server_endpoint, access_key, secret_key, secure=secure) + + log_output = LogOutput(client.select_object_content, 'test_csv_input_quote_char') + test_csv_input_quote_char(client, log_output) + + log_output = LogOutput(client.select_object_content, 'test_csv_output_quote_char') + test_csv_output_quote_char(client, log_output) + + + except Exception as err: + print(log_output.json_report(err)) + exit(1) + +if __name__ == "__main__": + # Execute only if run as a script + main() diff --git a/pkg/csvparser/reader.go b/pkg/csvparser/reader.go index 8163cc495..a50ef6dcc 100644 --- a/pkg/csvparser/reader.go +++ b/pkg/csvparser/reader.go @@ -113,6 +113,9 @@ type Reader struct { // or the Unicode replacement character (0xFFFD). Comma rune + // Quote is the single character used for marking fields limits + Quote []rune + // Comment, if not 0, is the comment character. Lines beginning with the // Comment character without preceding whitespace are ignored. // With leading whitespace the Comment character becomes part of the @@ -171,6 +174,7 @@ type Reader struct { func NewReader(r io.Reader) *Reader { return &Reader{ Comma: ',', + Quote: []rune(`"`), r: bufio.NewReader(r), } } @@ -255,6 +259,13 @@ func nextRune(b []byte) rune { return r } +func encodeRune(r rune) []byte { + rlen := utf8.RuneLen(r) + p := make([]byte, rlen) + _ = utf8.EncodeRune(p, r) + return p +} + func (r *Reader) readRecord(dst []string) ([]string, error) { if r.Comma == r.Comment || !validDelim(r.Comma) || (r.Comment != 0 && !validDelim(r.Comment)) { return nil, errInvalidDelim @@ -280,9 +291,17 @@ func (r *Reader) readRecord(dst []string) ([]string, error) { return nil, errRead } + var quote rune + var quoteLen int + if len(r.Quote) > 0 { + quote = r.Quote[0] + quoteLen = utf8.RuneLen(quote) + } + + encodedQuote := encodeRune(quote) + // Parse each field in the record. var err error - const quoteLen = len(`"`) commaLen := utf8.RuneLen(r.Comma) recLine := r.numLine // Starting line for record r.recordBuffer = r.recordBuffer[:0] @@ -292,7 +311,7 @@ parseField: if r.TrimLeadingSpace { line = bytes.TrimLeftFunc(line, unicode.IsSpace) } - if len(line) == 0 || line[0] != '"' { + if len(line) == 0 || quoteLen == 0 || nextRune(line) != quote { // Non-quoted string field i := bytes.IndexRune(line, r.Comma) field := line @@ -303,7 +322,7 @@ parseField: } // Check to make sure a quote does not appear in field. if !r.LazyQuotes { - if j := bytes.IndexByte(field, '"'); j >= 0 { + if j := bytes.IndexRune(field, quote); j >= 0 { col := utf8.RuneCount(fullLine[:len(fullLine)-len(line[j:])]) err = &ParseError{StartLine: recLine, Line: r.numLine, Column: col, Err: ErrBareQuote} break parseField @@ -320,15 +339,15 @@ parseField: // Quoted string field line = line[quoteLen:] for { - i := bytes.IndexByte(line, '"') + i := bytes.IndexRune(line, quote) if i >= 0 { // Hit next quote. r.recordBuffer = append(r.recordBuffer, line[:i]...) line = line[i+quoteLen:] switch rn := nextRune(line); { - case rn == '"': + case rn == quote: // `""` sequence (append quote). - r.recordBuffer = append(r.recordBuffer, '"') + r.recordBuffer = append(r.recordBuffer, encodedQuote...) line = line[quoteLen:] case rn == r.Comma: // `",` sequence (end of field). @@ -341,7 +360,7 @@ parseField: break parseField case r.LazyQuotes: // `"` sequence (bare quote). - r.recordBuffer = append(r.recordBuffer, '"') + r.recordBuffer = append(r.recordBuffer, encodedQuote...) default: // `"*` sequence (invalid non-escaped quote). col := utf8.RuneCount(fullLine[:len(fullLine)-len(line)-quoteLen]) diff --git a/pkg/csvparser/writer.go b/pkg/csvparser/writer.go index 6255acd97..cdcfc42b4 100644 --- a/pkg/csvparser/writer.go +++ b/pkg/csvparser/writer.go @@ -28,15 +28,18 @@ import ( // the underlying io.Writer. Any errors that occurred should // be checked by calling the Error method. type Writer struct { - Comma rune // Field delimiter (set to ',' by NewWriter) - UseCRLF bool // True to use \r\n as the line terminator - w *bufio.Writer + Comma rune // Field delimiter (set to ',' by NewWriter) + Quote rune // Fields quote character + AlwaysQuote bool // True to quote all fields + UseCRLF bool // True to use \r\n as the line terminator + w *bufio.Writer } // NewWriter returns a new Writer that writes to w. func NewWriter(w io.Writer) *Writer { return &Writer{ Comma: ',', + Quote: '"', w: bufio.NewWriter(w), } } @@ -59,19 +62,22 @@ func (w *Writer) Write(record []string) error { // If we don't have to have a quoted field then just // write out the field and continue to the next field. - if !w.fieldNeedsQuotes(field) { + if !w.AlwaysQuote && !w.fieldNeedsQuotes(field) { if _, err := w.w.WriteString(field); err != nil { return err } continue } - if err := w.w.WriteByte('"'); err != nil { + if _, err := w.w.WriteRune(w.Quote); err != nil { return err } + + specialChars := "\r\n" + string(w.Quote) + for len(field) > 0 { // Search for special characters. - i := strings.IndexAny(field, "\"\r\n") + i := strings.IndexAny(field, specialChars) if i < 0 { i = len(field) } @@ -85,9 +91,13 @@ func (w *Writer) Write(record []string) error { // Encode the special character. if len(field) > 0 { var err error - switch field[0] { - case '"': - _, err = w.w.WriteString(`""`) + switch nextRune([]byte(field)) { + case w.Quote: + _, err = w.w.WriteRune(w.Quote) + if err != nil { + break + } + _, err = w.w.WriteRune(w.Quote) case '\r': if !w.UseCRLF { err = w.w.WriteByte('\r') @@ -105,7 +115,7 @@ func (w *Writer) Write(record []string) error { } } } - if err := w.w.WriteByte('"'); err != nil { + if _, err := w.w.WriteRune(w.Quote); err != nil { return err } } @@ -158,7 +168,7 @@ func (w *Writer) fieldNeedsQuotes(field string) bool { if field == "" { return false } - if field == `\.` || strings.ContainsRune(field, w.Comma) || strings.ContainsAny(field, "\"\r\n") { + if field == `\.` || strings.ContainsAny(field, "\r\n"+string(w.Quote)+string(w.Comma)) { return true } diff --git a/pkg/csvparser/writer_test.go b/pkg/csvparser/writer_test.go index 7cb15a486..e64c15400 100644 --- a/pkg/csvparser/writer_test.go +++ b/pkg/csvparser/writer_test.go @@ -11,11 +11,13 @@ import ( ) var writeTests = []struct { - Input [][]string - Output string - Error error - UseCRLF bool - Comma rune + Input [][]string + Output string + Error error + UseCRLF bool + Comma rune + Quote rune + AlwaysQuote bool }{ {Input: [][]string{{"abc"}}, Output: "abc\n"}, {Input: [][]string{{"abc"}}, Output: "abc\r\n", UseCRLF: true}, @@ -46,6 +48,7 @@ var writeTests = []struct { {Input: [][]string{{"a", "a", ""}}, Output: "a|a|\n", Comma: '|'}, {Input: [][]string{{",", ",", ""}}, Output: ",|,|\n", Comma: '|'}, {Input: [][]string{{"foo"}}, Comma: '"', Error: errInvalidDelim}, + {Input: [][]string{{"a", "a", ""}}, Quote: '"', AlwaysQuote: true, Output: "\"a\"|\"a\"|\"\"\n", Comma: '|'}, } func TestWrite(t *testing.T) { @@ -56,6 +59,10 @@ func TestWrite(t *testing.T) { if tt.Comma != 0 { f.Comma = tt.Comma } + if tt.Quote != 0 { + f.Quote = tt.Quote + } + f.AlwaysQuote = tt.AlwaysQuote err := f.WriteAll(tt.Input) if err != tt.Error { t.Errorf("Unexpected error:\ngot %v\nwant %v", err, tt.Error) diff --git a/pkg/s3select/csv/args.go b/pkg/s3select/csv/args.go index 17d8bcfc2..bd03fcd9f 100644 --- a/pkg/s3select/csv/args.go +++ b/pkg/s3select/csv/args.go @@ -18,8 +18,11 @@ package csv import ( "encoding/xml" + "errors" "fmt" + "io" "strings" + "unicode/utf8" ) const ( @@ -55,68 +58,64 @@ func (args *ReaderArgs) IsEmpty() bool { } // UnmarshalXML - decodes XML data. -func (args *ReaderArgs) UnmarshalXML(d *xml.Decoder, start xml.StartElement) error { - // Make subtype to avoid recursive UnmarshalXML(). - type subReaderArgs ReaderArgs - parsedArgs := subReaderArgs{} - if err := d.DecodeElement(&parsedArgs, &start); err != nil { - return err +func (args *ReaderArgs) UnmarshalXML(d *xml.Decoder, start xml.StartElement) (err error) { + args.FileHeaderInfo = none + args.RecordDelimiter = defaultRecordDelimiter + args.FieldDelimiter = defaultFieldDelimiter + args.QuoteCharacter = defaultQuoteCharacter + args.QuoteEscapeCharacter = defaultQuoteEscapeCharacter + args.CommentCharacter = defaultCommentCharacter + args.AllowQuotedRecordDelimiter = false + + for { + // Read tokens from the XML document in a stream. + t, err := d.Token() + if err != nil { + if err == io.EOF { + break + } + return err + } + + switch se := t.(type) { + case xml.StartElement: + tagName := se.Name.Local + switch tagName { + case "AllowQuotedRecordDelimiter": + var b bool + if err = d.DecodeElement(&b, &se); err != nil { + return err + } + args.AllowQuotedRecordDelimiter = b + default: + var s string + if err = d.DecodeElement(&s, &se); err != nil { + return err + } + switch tagName { + case "FileHeaderInfo": + args.FileHeaderInfo = strings.ToLower(s) + case "RecordDelimiter": + args.RecordDelimiter = s + case "FieldDelimiter": + args.FieldDelimiter = s + case "QuoteCharacter": + if utf8.RuneCountInString(s) > 1 { + return fmt.Errorf("unsupported QuoteCharacter '%v'", s) + } + args.QuoteCharacter = s + // Not supported yet + case "QuoteEscapeCharacter": + case "Comments": + args.CommentCharacter = s + default: + return errors.New("unrecognized option") + } + } + } } - parsedArgs.FileHeaderInfo = strings.ToLower(parsedArgs.FileHeaderInfo) - switch parsedArgs.FileHeaderInfo { - case "": - parsedArgs.FileHeaderInfo = none - case none, use, ignore: - default: - return errInvalidFileHeaderInfo(fmt.Errorf("invalid FileHeaderInfo '%v'", parsedArgs.FileHeaderInfo)) - } - - switch len([]rune(parsedArgs.RecordDelimiter)) { - case 0: - parsedArgs.RecordDelimiter = defaultRecordDelimiter - case 1, 2: - default: - return fmt.Errorf("invalid RecordDelimiter '%v'", parsedArgs.RecordDelimiter) - } - - switch len([]rune(parsedArgs.FieldDelimiter)) { - case 0: - parsedArgs.FieldDelimiter = defaultFieldDelimiter - case 1: - default: - return fmt.Errorf("invalid FieldDelimiter '%v'", parsedArgs.FieldDelimiter) - } - - switch parsedArgs.QuoteCharacter { - case "": - parsedArgs.QuoteCharacter = defaultQuoteCharacter - case defaultQuoteCharacter: - default: - return fmt.Errorf("unsupported QuoteCharacter '%v'", parsedArgs.QuoteCharacter) - } - - switch parsedArgs.QuoteEscapeCharacter { - case "": - parsedArgs.QuoteEscapeCharacter = defaultQuoteEscapeCharacter - case defaultQuoteEscapeCharacter: - default: - return fmt.Errorf("unsupported QuoteEscapeCharacter '%v'", parsedArgs.QuoteEscapeCharacter) - } - - switch parsedArgs.CommentCharacter { - case "": - parsedArgs.CommentCharacter = defaultCommentCharacter - case defaultCommentCharacter: - default: - return fmt.Errorf("unsupported Comments '%v'", parsedArgs.CommentCharacter) - } - - if parsedArgs.AllowQuotedRecordDelimiter { - return fmt.Errorf("flag AllowQuotedRecordDelimiter is unsupported at the moment") - } - - *args = ReaderArgs(parsedArgs) + args.QuoteEscapeCharacter = args.QuoteCharacter args.unmarshaled = true return nil } @@ -138,55 +137,54 @@ func (args *WriterArgs) IsEmpty() bool { // UnmarshalXML - decodes XML data. func (args *WriterArgs) UnmarshalXML(d *xml.Decoder, start xml.StartElement) error { - // Make subtype to avoid recursive UnmarshalXML(). - type subWriterArgs WriterArgs - parsedArgs := subWriterArgs{} - if err := d.DecodeElement(&parsedArgs, &start); err != nil { - return err - } - - parsedArgs.QuoteFields = strings.ToLower(parsedArgs.QuoteFields) - switch parsedArgs.QuoteFields { - case "": - parsedArgs.QuoteFields = asneeded - case always, asneeded: - default: - return errInvalidQuoteFields(fmt.Errorf("invalid QuoteFields '%v'", parsedArgs.QuoteFields)) - } - - switch len([]rune(parsedArgs.RecordDelimiter)) { - case 0: - parsedArgs.RecordDelimiter = defaultRecordDelimiter - case 1, 2: - default: - return fmt.Errorf("invalid RecordDelimiter '%v'", parsedArgs.RecordDelimiter) - } - - switch len([]rune(parsedArgs.FieldDelimiter)) { - case 0: - parsedArgs.FieldDelimiter = defaultFieldDelimiter - case 1: - default: - return fmt.Errorf("invalid FieldDelimiter '%v'", parsedArgs.FieldDelimiter) - } - - switch parsedArgs.QuoteCharacter { - case "": - parsedArgs.QuoteCharacter = defaultQuoteCharacter - case defaultQuoteCharacter: - default: - return fmt.Errorf("unsupported QuoteCharacter '%v'", parsedArgs.QuoteCharacter) - } - switch parsedArgs.QuoteEscapeCharacter { - case "": - parsedArgs.QuoteEscapeCharacter = defaultQuoteEscapeCharacter - case defaultQuoteEscapeCharacter: - default: - return fmt.Errorf("unsupported QuoteEscapeCharacter '%v'", parsedArgs.QuoteEscapeCharacter) + args.QuoteFields = asneeded + args.RecordDelimiter = defaultRecordDelimiter + args.FieldDelimiter = defaultFieldDelimiter + args.QuoteCharacter = defaultQuoteCharacter + args.QuoteEscapeCharacter = defaultQuoteCharacter + + for { + // Read tokens from the XML document in a stream. + t, err := d.Token() + if err != nil { + if err == io.EOF { + break + } + return err + } + + switch se := t.(type) { + case xml.StartElement: + var s string + if err = d.DecodeElement(&s, &se); err != nil { + return err + } + switch se.Name.Local { + case "QuoteFields": + args.QuoteFields = strings.ToLower(s) + case "RecordDelimiter": + args.RecordDelimiter = s + case "FieldDelimiter": + args.FieldDelimiter = s + case "QuoteCharacter": + switch utf8.RuneCountInString(s) { + case 0: + args.QuoteCharacter = "\x00" + case 1: + args.QuoteCharacter = s + default: + return fmt.Errorf("unsupported QuoteCharacter '%v'", s) + } + // Not supported yet + case "QuoteEscapeCharacter": + default: + return errors.New("unrecognized option") + } + } } - *args = WriterArgs(parsedArgs) + args.QuoteEscapeCharacter = args.QuoteCharacter args.unmarshaled = true return nil } diff --git a/pkg/s3select/csv/reader.go b/pkg/s3select/csv/reader.go index e60af9896..9d0fc1171 100644 --- a/pkg/s3select/csv/reader.go +++ b/pkg/s3select/csv/reader.go @@ -294,6 +294,11 @@ func NewReader(readCloser io.ReadCloser, args *ReaderArgs) (*Reader, error) { ret := csv.NewReader(r) ret.Comma = []rune(args.FieldDelimiter)[0] ret.Comment = []rune(args.CommentCharacter)[0] + ret.Quote = []rune{} + if len([]rune(args.QuoteCharacter)) > 0 { + // Add the first rune of args.QuoteChracter + ret.Quote = append(ret.Quote, []rune(args.QuoteCharacter)[0]) + } ret.FieldsPerRecord = -1 // If LazyQuotes is true, a quote may appear in an unquoted field and a // non-doubled quote may appear in a quoted field. diff --git a/pkg/s3select/csv/reader_test.go b/pkg/s3select/csv/reader_test.go index 81a39a5f9..9aed957b0 100644 --- a/pkg/s3select/csv/reader_test.go +++ b/pkg/s3select/csv/reader_test.go @@ -63,7 +63,7 @@ func TestRead(t *testing.T) { if err != nil { break } - record.WriteCSV(&result, []rune(c.fieldDelimiter)[0]) + record.WriteCSV(&result, []rune(c.fieldDelimiter)[0], '"', false) result.Truncate(result.Len() - 1) result.WriteString(c.recordDelimiter) } @@ -243,7 +243,7 @@ func TestReadExtended(t *testing.T) { } if fields < 10 { // Write with fixed delimiters, newlines. - err := record.WriteCSV(&result, ',') + err := record.WriteCSV(&result, ',', '"', false) if err != nil { t.Error(err) } @@ -454,7 +454,7 @@ func TestReadFailures(t *testing.T) { break } // Write with fixed delimiters, newlines. - err := record.WriteCSV(&result, ',') + err := record.WriteCSV(&result, ',', '"', false) if err != nil { t.Error(err) } diff --git a/pkg/s3select/csv/record.go b/pkg/s3select/csv/record.go index 5ad3c3374..7cda5e4bb 100644 --- a/pkg/s3select/csv/record.go +++ b/pkg/s3select/csv/record.go @@ -92,9 +92,11 @@ func (r *Record) Clone(dst sql.Record) sql.Record { } // WriteCSV - encodes to CSV data. -func (r *Record) WriteCSV(writer io.Writer, fieldDelimiter rune) error { +func (r *Record) WriteCSV(writer io.Writer, fieldDelimiter rune, quote rune, alwaysQuote bool) error { w := csv.NewWriter(writer) w.Comma = fieldDelimiter + w.AlwaysQuote = alwaysQuote + w.Quote = quote if err := w.Write(r.csvRecord); err != nil { return err } diff --git a/pkg/s3select/json/record.go b/pkg/s3select/json/record.go index 9d211ec22..6410a0224 100644 --- a/pkg/s3select/json/record.go +++ b/pkg/s3select/json/record.go @@ -17,7 +17,6 @@ package json import ( - "encoding/csv" "encoding/json" "errors" "fmt" @@ -27,6 +26,7 @@ import ( "strings" "github.com/bcicen/jstream" + csv "github.com/minio/minio/pkg/csvparser" "github.com/minio/minio/pkg/s3select/sql" ) @@ -108,7 +108,7 @@ func (r *Record) Set(name string, value *sql.Value) (sql.Record, error) { } // WriteCSV - encodes to CSV data. -func (r *Record) WriteCSV(writer io.Writer, fieldDelimiter rune) error { +func (r *Record) WriteCSV(writer io.Writer, fieldDelimiter rune, quote rune, alwaysQuote bool) error { var csvRecord []string for _, kv := range r.KVS { var columnValue string @@ -137,6 +137,8 @@ func (r *Record) WriteCSV(writer io.Writer, fieldDelimiter rune) error { w := csv.NewWriter(writer) w.Comma = fieldDelimiter + w.Quote = quote + w.AlwaysQuote = alwaysQuote if err := w.Write(csvRecord); err != nil { return err } diff --git a/pkg/s3select/select.go b/pkg/s3select/select.go index 07ac05193..1230d666b 100644 --- a/pkg/s3select/select.go +++ b/pkg/s3select/select.go @@ -353,7 +353,10 @@ func (s3Select *S3Select) marshal(buf *bytes.Buffer, record sql.Record) error { }() bufioWriter.Reset(buf) - err := record.WriteCSV(bufioWriter, []rune(s3Select.Output.CSVArgs.FieldDelimiter)[0]) + err := record.WriteCSV(bufioWriter, + []rune(s3Select.Output.CSVArgs.FieldDelimiter)[0], + []rune(s3Select.Output.CSVArgs.QuoteCharacter)[0], + strings.ToLower(s3Select.Output.CSVArgs.QuoteFields) == "always") if err != nil { return err } diff --git a/pkg/s3select/select_test.go b/pkg/s3select/select_test.go index b14a17ecd..d1113d9f0 100644 --- a/pkg/s3select/select_test.go +++ b/pkg/s3select/select_test.go @@ -252,6 +252,7 @@ func TestJSONQueries(t *testing.T) { + " @@ -587,6 +588,7 @@ func TestCSVQueries2(t *testing.T) { NONE USE + " diff --git a/pkg/s3select/simdj/reader_test.go b/pkg/s3select/simdj/reader_test.go index 36a296570..012beff55 100644 --- a/pkg/s3select/simdj/reader_test.go +++ b/pkg/s3select/simdj/reader_test.go @@ -131,11 +131,11 @@ func TestNDJSON(t *testing.T) { t.Error(err) } var gotB, wantB bytes.Buffer - err = rec.WriteCSV(&gotB, ',') + err = rec.WriteCSV(&gotB, ',', '"', false) if err != nil { t.Error(err) } - err = want.WriteCSV(&wantB, ',') + err = want.WriteCSV(&wantB, ',', '"', false) if err != nil { t.Error(err) } diff --git a/pkg/s3select/simdj/record.go b/pkg/s3select/simdj/record.go index d67c56d37..38ccafd81 100644 --- a/pkg/s3select/simdj/record.go +++ b/pkg/s3select/simdj/record.go @@ -17,10 +17,11 @@ package simdj import ( - "encoding/csv" "fmt" "io" + csv "github.com/minio/minio/pkg/csvparser" + "github.com/bcicen/jstream" "github.com/minio/minio/pkg/s3select/json" "github.com/minio/minio/pkg/s3select/sql" @@ -140,7 +141,7 @@ func (r *Record) Set(name string, value *sql.Value) (sql.Record, error) { } // WriteCSV - encodes to CSV data. -func (r *Record) WriteCSV(writer io.Writer, fieldDelimiter rune) error { +func (r *Record) WriteCSV(writer io.Writer, fieldDelimiter, quote rune, alwaysQuote bool) error { csvRecord := make([]string, 0, 10) var tmp simdjson.Iter obj := r.object @@ -173,6 +174,8 @@ allElems: } w := csv.NewWriter(writer) w.Comma = fieldDelimiter + w.Quote = quote + w.AlwaysQuote = alwaysQuote if err := w.Write(csvRecord); err != nil { return err } diff --git a/pkg/s3select/sql/record.go b/pkg/s3select/sql/record.go index 4f34d73ee..8e375765c 100644 --- a/pkg/s3select/sql/record.go +++ b/pkg/s3select/sql/record.go @@ -46,7 +46,7 @@ type Record interface { // Set a value. // Can return a different record type. Set(name string, value *Value) (Record, error) - WriteCSV(writer io.Writer, fieldDelimiter rune) error + WriteCSV(writer io.Writer, fieldDelimiter, quote rune, alwaysQuote bool) error WriteJSON(writer io.Writer) error // Clone the record and if possible use the destination provided.