diff options
Diffstat (limited to 'vendor/google.golang.org/grpc/transport/http_util.go')
-rw-r--r-- | vendor/google.golang.org/grpc/transport/http_util.go | 54 |
1 files changed, 36 insertions, 18 deletions
diff --git a/vendor/google.golang.org/grpc/transport/http_util.go b/vendor/google.golang.org/grpc/transport/http_util.go index a35586608..7d15c7d74 100644 --- a/vendor/google.golang.org/grpc/transport/http_util.go +++ b/vendor/google.golang.org/grpc/transport/http_util.go @@ -28,6 +28,7 @@ import ( "strconv" "strings" "time" + "unicode/utf8" "github.com/golang/protobuf/proto" "golang.org/x/net/http2" @@ -437,16 +438,17 @@ func decodeTimeout(s string) (time.Duration, error) { const ( spaceByte = ' ' - tildaByte = '~' + tildeByte = '~' percentByte = '%' ) // encodeGrpcMessage is used to encode status code in header field -// "grpc-message". -// It checks to see if each individual byte in msg is an -// allowable byte, and then either percent encoding or passing it through. -// When percent encoding, the byte is converted into hexadecimal notation -// with a '%' prepended. +// "grpc-message". It does percent encoding and also replaces invalid utf-8 +// characters with Unicode replacement character. +// +// It checks to see if each individual byte in msg is an allowable byte, and +// then either percent encoding or passing it through. When percent encoding, +// the byte is converted into hexadecimal notation with a '%' prepended. func encodeGrpcMessage(msg string) string { if msg == "" { return "" @@ -454,7 +456,7 @@ func encodeGrpcMessage(msg string) string { lenMsg := len(msg) for i := 0; i < lenMsg; i++ { c := msg[i] - if !(c >= spaceByte && c < tildaByte && c != percentByte) { + if !(c >= spaceByte && c <= tildeByte && c != percentByte) { return encodeGrpcMessageUnchecked(msg) } } @@ -463,14 +465,26 @@ func encodeGrpcMessage(msg string) string { func encodeGrpcMessageUnchecked(msg string) string { var buf bytes.Buffer - lenMsg := len(msg) - for i := 0; i < lenMsg; i++ { - c := msg[i] - if c >= spaceByte && c < tildaByte && c != percentByte { - buf.WriteByte(c) - } else { - buf.WriteString(fmt.Sprintf("%%%02X", c)) + for len(msg) > 0 { + r, size := utf8.DecodeRuneInString(msg) + for _, b := range []byte(string(r)) { + if size > 1 { + // If size > 1, r is not ascii. Always do percent encoding. + buf.WriteString(fmt.Sprintf("%%%02X", b)) + continue + } + + // The for loop is necessary even if size == 1. r could be + // utf8.RuneError. + // + // fmt.Sprintf("%%%02X", utf8.RuneError) gives "%FFFD". + if b >= spaceByte && b <= tildeByte && b != percentByte { + buf.WriteByte(b) + } else { + buf.WriteString(fmt.Sprintf("%%%02X", b)) + } } + msg = msg[size:] } return buf.String() } @@ -531,10 +545,14 @@ func (w *bufWriter) Write(b []byte) (n int, err error) { if w.err != nil { return 0, w.err } - n = copy(w.buf[w.offset:], b) - w.offset += n - if w.offset >= w.batchSize { - err = w.Flush() + for len(b) > 0 { + nn := copy(w.buf[w.offset:], b) + b = b[nn:] + w.offset += nn + n += nn + if w.offset >= w.batchSize { + err = w.Flush() + } } return n, err } |