// mgo - MongoDB driver for Go // // Copyright (c) 2010-2012 - Gustavo Niemeyer <gustavo@niemeyer.net> // // All rights reserved. // // Redistribution and use in source and binary forms, with or without // modification, are permitted provided that the following conditions are met: // // 1. Redistributions of source code must retain the above copyright notice, this // list of conditions and the following disclaimer. // 2. Redistributions in binary form must reproduce the above copyright notice, // this list of conditions and the following disclaimer in the documentation // and/or other materials provided with the distribution. // // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND // ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED // WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE // DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR // ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES // (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; // LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND // ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS // SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. package mgo import ( "errors" "fmt" "net" "sync" "time" "gopkg.in/mgo.v2/bson" ) type replyFunc func(err error, reply *replyOp, docNum int, docData []byte) type mongoSocket struct { sync.Mutex server *mongoServer // nil when cached conn net.Conn timeout time.Duration addr string // For debugging only. nextRequestId uint32 replyFuncs map[uint32]replyFunc references int creds []Credential logout []Credential cachedNonce string gotNonce sync.Cond dead error serverInfo *mongoServerInfo } type queryOpFlags uint32 const ( _ queryOpFlags = 1 << iota flagTailable flagSlaveOk flagLogReplay flagNoCursorTimeout flagAwaitData ) type queryOp struct { collection string query interface{} skip int32 limit int32 selector interface{} flags queryOpFlags replyFunc replyFunc mode Mode options queryWrapper hasOptions bool serverTags []bson.D } type queryWrapper struct { Query interface{} "$query" OrderBy interface{} "$orderby,omitempty" Hint interface{} "$hint,omitempty" Explain bool "$explain,omitempty" Snapshot bool "$snapshot,omitempty" ReadPreference bson.D "$readPreference,omitempty" MaxScan int "$maxScan,omitempty" MaxTimeMS int "$maxTimeMS,omitempty" Comment string "$comment,omitempty" } func (op *queryOp) finalQuery(socket *mongoSocket) interface{} { if socket.ServerInfo().Mongos { var modeName string if op.flags&flagSlaveOk == 0 { modeName = "primary" } else { switch op.mode { case Strong: modeName = "primary" case Monotonic, Eventual: modeName = "secondaryPreferred" case PrimaryPreferred: modeName = "primaryPreferred" case Secondary: modeName = "secondary" case SecondaryPreferred: modeName = "secondaryPreferred" case Nearest: modeName = "nearest" default: panic(fmt.Sprintf("unsupported read mode: %d", op.mode)) } } op.hasOptions = true op.options.ReadPreference = make(bson.D, 0, 2) op.options.ReadPreference = append(op.options.ReadPreference, bson.DocElem{"mode", modeName}) if len(op.serverTags) > 0 { op.options.ReadPreference = append(op.options.ReadPreference, bson.DocElem{"tags", op.serverTags}) } } if op.hasOptions { if op.query == nil { var empty bson.D op.options.Query = empty } else { op.options.Query = op.query } debugf("final query is %#v\n", &op.options) return &op.options } return op.query } type getMoreOp struct { collection string limit int32 cursorId int64 replyFunc replyFunc } type replyOp struct { flags uint32 cursorId int64 firstDoc int32 replyDocs int32 } type insertOp struct { collection string // "database.collection" documents []interface{} // One or more documents to insert flags uint32 } type updateOp struct { Collection string `bson:"-"` // "database.collection" Selector interface{} `bson:"q"` Update interface{} `bson:"u"` Flags uint32 `bson:"-"` Multi bool `bson:"multi,omitempty"` Upsert bool `bson:"upsert,omitempty"` } type deleteOp struct { collection string // "database.collection" selector interface{} flags uint32 } type killCursorsOp struct { cursorIds []int64 } type requestInfo struct { bufferPos int replyFunc replyFunc } func newSocket(server *mongoServer, conn net.Conn, timeout time.Duration) *mongoSocket { socket := &mongoSocket{ conn: conn, addr: server.Addr, server: server, replyFuncs: make(map[uint32]replyFunc), } socket.gotNonce.L = &socket.Mutex if err := socket.InitialAcquire(server.Info(), timeout); err != nil { panic("newSocket: InitialAcquire returned error: " + err.Error()) } stats.socketsAlive(+1) debugf("Socket %p to %s: initialized", socket, socket.addr) socket.resetNonce() go socket.readLoop() return socket } // Server returns the server that the socket is associated with. // It returns nil while the socket is cached in its respective server. func (socket *mongoSocket) Server() *mongoServer { socket.Lock() server := socket.server socket.Unlock() return server } // ServerInfo returns details for the server at the time the socket // was initially acquired. func (socket *mongoSocket) ServerInfo() *mongoServerInfo { socket.Lock() serverInfo := socket.serverInfo socket.Unlock() return serverInfo } // InitialAcquire obtains the first reference to the socket, either // right after the connection is made or once a recycled socket is // being put back in use. func (socket *mongoSocket) InitialAcquire(serverInfo *mongoServerInfo, timeout time.Duration) error { socket.Lock() if socket.references > 0 { panic("Socket acquired out of cache with references") } if socket.dead != nil { dead := socket.dead socket.Unlock() return dead } socket.references++ socket.serverInfo = serverInfo socket.timeout = timeout stats.socketsInUse(+1) stats.socketRefs(+1) socket.Unlock() return nil } // Acquire obtains an additional reference to the socket. // The socket will only be recycled when it's released as many // times as it's been acquired. func (socket *mongoSocket) Acquire() (info *mongoServerInfo) { socket.Lock() if socket.references == 0 { panic("Socket got non-initial acquire with references == 0") } // We'll track references to dead sockets as well. // Caller is still supposed to release the socket. socket.references++ stats.socketRefs(+1) serverInfo := socket.serverInfo socket.Unlock() return serverInfo } // Release decrements a socket reference. The socket will be // recycled once its released as many times as it's been acquired. func (socket *mongoSocket) Release() { socket.Lock() if socket.references == 0 { panic("socket.Release() with references == 0") } socket.references-- stats.socketRefs(-1) if socket.references == 0 { stats.socketsInUse(-1) server := socket.server socket.Unlock() socket.LogoutAll() // If the socket is dead server is nil. if server != nil { server.RecycleSocket(socket) } } else { socket.Unlock() } } // SetTimeout changes the timeout used on socket operations. func (socket *mongoSocket) SetTimeout(d time.Duration) { socket.Lock() socket.timeout = d socket.Unlock() } type deadlineType int const ( readDeadline deadlineType = 1 writeDeadline deadlineType = 2 ) func (socket *mongoSocket) updateDeadline(which deadlineType) { var when time.Time if socket.timeout > 0 { when = time.Now().Add(socket.timeout) } whichstr := "" switch which { case readDeadline | writeDeadline: whichstr = "read/write" socket.conn.SetDeadline(when) case readDeadline: whichstr = "read" socket.conn.SetReadDeadline(when) case writeDeadline: whichstr = "write" socket.conn.SetWriteDeadline(when) default: panic("invalid parameter to updateDeadline") } debugf("Socket %p to %s: updated %s deadline to %s ahead (%s)", socket, socket.addr, whichstr, socket.timeout, when) } // Close terminates the socket use. func (socket *mongoSocket) Close() { socket.kill(errors.New("Closed explicitly"), false) } func (socket *mongoSocket) kill(err error, abend bool) { socket.Lock() if socket.dead != nil { debugf("Socket %p to %s: killed again: %s (previously: %s)", socket, socket.addr, err.Error(), socket.dead.Error()) socket.Unlock() return } logf("Socket %p to %s: closing: %s (abend=%v)", socket, socket.addr, err.Error(), abend) socket.dead = err socket.conn.Close() stats.socketsAlive(-1) replyFuncs := socket.replyFuncs socket.replyFuncs = make(map[uint32]replyFunc) server := socket.server socket.server = nil socket.gotNonce.Broadcast() socket.Unlock() for _, replyFunc := range replyFuncs { logf("Socket %p to %s: notifying replyFunc of closed socket: %s", socket, socket.addr, err.Error()) replyFunc(err, nil, -1, nil) } if abend { server.AbendSocket(socket) } } func (socket *mongoSocket) SimpleQuery(op *queryOp) (data []byte, err error) { var wait, change sync.Mutex var replyDone bool var replyData []byte var replyErr error wait.Lock() op.replyFunc = func(err error, reply *replyOp, docNum int, docData []byte) { change.Lock() if !replyDone { replyDone = true replyErr = err if err == nil { replyData = docData } } change.Unlock() wait.Unlock() } err = socket.Query(op) if err != nil { return nil, err } wait.Lock() change.Lock() data = replyData err = replyErr change.Unlock() return data, err } func (socket *mongoSocket) Query(ops ...interface{}) (err error) { if lops := socket.flushLogout(); len(lops) > 0 { ops = append(lops, ops...) } buf := make([]byte, 0, 256) // Serialize operations synchronously to avoid interrupting // other goroutines while we can't really be sending data. // Also, record id positions so that we can compute request // ids at once later with the lock already held. requests := make([]requestInfo, len(ops)) requestCount := 0 for _, op := range ops { debugf("Socket %p to %s: serializing op: %#v", socket, socket.addr, op) start := len(buf) var replyFunc replyFunc switch op := op.(type) { case *updateOp: buf = addHeader(buf, 2001) buf = addInt32(buf, 0) // Reserved buf = addCString(buf, op.Collection) buf = addInt32(buf, int32(op.Flags)) debugf("Socket %p to %s: serializing selector document: %#v", socket, socket.addr, op.Selector) buf, err = addBSON(buf, op.Selector) if err != nil { return err } debugf("Socket %p to %s: serializing update document: %#v", socket, socket.addr, op.Update) buf, err = addBSON(buf, op.Update) if err != nil { return err } case *insertOp: buf = addHeader(buf, 2002) buf = addInt32(buf, int32(op.flags)) buf = addCString(buf, op.collection) for _, doc := range op.documents { debugf("Socket %p to %s: serializing document for insertion: %#v", socket, socket.addr, doc) buf, err = addBSON(buf, doc) if err != nil { return err } } case *queryOp: buf = addHeader(buf, 2004) buf = addInt32(buf, int32(op.flags)) buf = addCString(buf, op.collection) buf = addInt32(buf, op.skip) buf = addInt32(buf, op.limit) buf, err = addBSON(buf, op.finalQuery(socket)) if err != nil { return err } if op.selector != nil { buf, err = addBSON(buf, op.selector) if err != nil { return err } } replyFunc = op.replyFunc case *getMoreOp: buf = addHeader(buf, 2005) buf = addInt32(buf, 0) // Reserved buf = addCString(buf, op.collection) buf = addInt32(buf, op.limit) buf = addInt64(buf, op.cursorId) replyFunc = op.replyFunc case *deleteOp: buf = addHeader(buf, 2006) buf = addInt32(buf, 0) // Reserved buf = addCString(buf, op.collection) buf = addInt32(buf, int32(op.flags)) debugf("Socket %p to %s: serializing selector document: %#v", socket, socket.addr, op.selector) buf, err = addBSON(buf, op.selector) if err != nil { return err } case *killCursorsOp: buf = addHeader(buf, 2007) buf = addInt32(buf, 0) // Reserved buf = addInt32(buf, int32(len(op.cursorIds))) for _, cursorId := range op.cursorIds { buf = addInt64(buf, cursorId) } default: panic("internal error: unknown operation type") } setInt32(buf, start, int32(len(buf)-start)) if replyFunc != nil { request := &requests[requestCount] request.replyFunc = replyFunc request.bufferPos = start requestCount++ } } // Buffer is ready for the pipe. Lock, allocate ids, and enqueue. socket.Lock() if socket.dead != nil { dead := socket.dead socket.Unlock() debugf("Socket %p to %s: failing query, already closed: %s", socket, socket.addr, socket.dead.Error()) // XXX This seems necessary in case the session is closed concurrently // with a query being performed, but it's not yet tested: for i := 0; i != requestCount; i++ { request := &requests[i] if request.replyFunc != nil { request.replyFunc(dead, nil, -1, nil) } } return dead } wasWaiting := len(socket.replyFuncs) > 0 // Reserve id 0 for requests which should have no responses. requestId := socket.nextRequestId + 1 if requestId == 0 { requestId++ } socket.nextRequestId = requestId + uint32(requestCount) for i := 0; i != requestCount; i++ { request := &requests[i] setInt32(buf, request.bufferPos+4, int32(requestId)) socket.replyFuncs[requestId] = request.replyFunc requestId++ } debugf("Socket %p to %s: sending %d op(s) (%d bytes)", socket, socket.addr, len(ops), len(buf)) stats.sentOps(len(ops)) socket.updateDeadline(writeDeadline) _, err = socket.conn.Write(buf) if !wasWaiting && requestCount > 0 { socket.updateDeadline(readDeadline) } socket.Unlock() return err } func fill(r net.Conn, b []byte) error { l := len(b) n, err := r.Read(b) for n != l && err == nil { var ni int ni, err = r.Read(b[n:]) n += ni } return err } // Estimated minimum cost per socket: 1 goroutine + memory for the largest // document ever seen. func (socket *mongoSocket) readLoop() { p := make([]byte, 36) // 16 from header + 20 from OP_REPLY fixed fields s := make([]byte, 4) conn := socket.conn // No locking, conn never changes. for { // XXX Handle timeouts, , etc err := fill(conn, p) if err != nil { socket.kill(err, true) return } totalLen := getInt32(p, 0) responseTo := getInt32(p, 8) opCode := getInt32(p, 12) // Don't use socket.server.Addr here. socket is not // locked and socket.server may go away. debugf("Socket %p to %s: got reply (%d bytes)", socket, socket.addr, totalLen) _ = totalLen if opCode != 1 { socket.kill(errors.New("opcode != 1, corrupted data?"), true) return } reply := replyOp{ flags: uint32(getInt32(p, 16)), cursorId: getInt64(p, 20), firstDoc: getInt32(p, 28), replyDocs: getInt32(p, 32), } stats.receivedOps(+1) stats.receivedDocs(int(reply.replyDocs)) socket.Lock() replyFunc, ok := socket.replyFuncs[uint32(responseTo)] if ok { delete(socket.replyFuncs, uint32(responseTo)) } socket.Unlock() if replyFunc != nil && reply.replyDocs == 0 { replyFunc(nil, &reply, -1, nil) } else { for i := 0; i != int(reply.replyDocs); i++ { err := fill(conn, s) if err != nil { if replyFunc != nil { replyFunc(err, nil, -1, nil) } socket.kill(err, true) return } b := make([]byte, int(getInt32(s, 0))) // copy(b, s) in an efficient way. b[0] = s[0] b[1] = s[1] b[2] = s[2] b[3] = s[3] err = fill(conn, b[4:]) if err != nil { if replyFunc != nil { replyFunc(err, nil, -1, nil) } socket.kill(err, true) return } if globalDebug && globalLogger != nil { m := bson.M{} if err := bson.Unmarshal(b, m); err == nil { debugf("Socket %p to %s: received document: %#v", socket, socket.addr, m) } } if replyFunc != nil { replyFunc(nil, &reply, i, b) } // XXX Do bound checking against totalLen. } } socket.Lock() if len(socket.replyFuncs) == 0 { // Nothing else to read for now. Disable deadline. socket.conn.SetReadDeadline(time.Time{}) } else { socket.updateDeadline(readDeadline) } socket.Unlock() // XXX Do bound checking against totalLen. } } var emptyHeader = []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0} func addHeader(b []byte, opcode int) []byte { i := len(b) b = append(b, emptyHeader...) // Enough for current opcodes. b[i+12] = byte(opcode) b[i+13] = byte(opcode >> 8) return b } func addInt32(b []byte, i int32) []byte { return append(b, byte(i), byte(i>>8), byte(i>>16), byte(i>>24)) } func addInt64(b []byte, i int64) []byte { return append(b, byte(i), byte(i>>8), byte(i>>16), byte(i>>24), byte(i>>32), byte(i>>40), byte(i>>48), byte(i>>56)) } func addCString(b []byte, s string) []byte { b = append(b, []byte(s)...) b = append(b, 0) return b } func addBSON(b []byte, doc interface{}) ([]byte, error) { if doc == nil { return append(b, 5, 0, 0, 0, 0), nil } data, err := bson.Marshal(doc) if err != nil { return b, err } return append(b, data...), nil } func setInt32(b []byte, pos int, i int32) { b[pos] = byte(i) b[pos+1] = byte(i >> 8) b[pos+2] = byte(i >> 16) b[pos+3] = byte(i >> 24) } func getInt32(b []byte, pos int) int32 { return (int32(b[pos+0])) | (int32(b[pos+1]) << 8) | (int32(b[pos+2]) << 16) | (int32(b[pos+3]) << 24) } func getInt64(b []byte, pos int) int64 { return (int64(b[pos+0])) | (int64(b[pos+1]) << 8) | (int64(b[pos+2]) << 16) | (int64(b[pos+3]) << 24) | (int64(b[pos+4]) << 32) | (int64(b[pos+5]) << 40) | (int64(b[pos+6]) << 48) | (int64(b[pos+7]) << 56) }