mapping.go 9.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405
  1. package binder
import (
	"fmt"
	"maps"
	"mime/multipart"
	"reflect"
	"slices"
	"strings"
	"sync"

	"github.com/gofiber/schema"
	utilsstrings "github.com/gofiber/utils/v2/strings"
	"github.com/valyala/bytebufferpool"
)
// ParserConfig is the form decoder configuration accepted by SetParserDecoder.
// Each field is applied to every schema decoder built for the binders' alias
// tags. Until SetParserDecoder is called, the package defaults installed by
// init are IgnoreUnknownKeys: true and ZeroEmpty: true.
type ParserConfig struct {
	// SetAliasTag, when non-empty, is passed to the decoder's SetAliasTag.
	// NOTE(review): parseToStruct calls SetAliasTag with the binder's own tag
	// before every Decode, so this value appears to be overridden for the
	// binders in this file — confirm whether it is still meaningful.
	SetAliasTag string
	// ParserType lists custom types and their converters; each entry is
	// registered on every decoder via RegisterConverter.
	ParserType []ParserType
	// IgnoreUnknownKeys is forwarded to the decoder's IgnoreUnknownKeys.
	IgnoreUnknownKeys bool
	// ZeroEmpty is forwarded to the decoder's ZeroEmpty.
	ZeroEmpty bool
}
// ParserType pairs a custom type with the converter used to decode it; both
// elements are required for registration.
// Use ParserType with BodyParser for parsing custom type in form data.
// Entries are supplied through ParserConfig.ParserType.
type ParserType struct {
	// CustomType is a sample value of the custom type; it is handed to the
	// decoder's RegisterConverter (presumably keyed by its dynamic type —
	// confirm against gofiber/schema).
	CustomType any
	// Converter is expected to turn one raw string value into a reflect.Value
	// holding the custom type.
	Converter func(string) reflect.Value
}
  26. var (
  27. decoderPoolMu sync.RWMutex
  28. // decoderPoolMap helps to improve binders
  29. decoderPoolMap = map[string]*sync.Pool{}
  30. // tags is used to classify parser's pool
  31. tags = []string{"header", "respHeader", "cookie", "query", "form", "uri"}
  32. )
  33. func getDecoderPool(tag string) *sync.Pool {
  34. decoderPoolMu.RLock()
  35. pool := decoderPoolMap[tag]
  36. if pool == nil {
  37. decoderPoolMu.RUnlock()
  38. panic(fmt.Sprintf("decoder pool not initialized for tag %q", tag))
  39. }
  40. decoderPoolMu.RUnlock()
  41. return pool
  42. }
  43. // SetParserDecoder allow globally change the option of form decoder, update decoderPool
  44. func SetParserDecoder(parserConfig ParserConfig) {
  45. decoderPoolMu.Lock()
  46. defer decoderPoolMu.Unlock()
  47. for _, tag := range tags {
  48. decoderPoolMap[tag] = &sync.Pool{New: func() any {
  49. return decoderBuilder(parserConfig)
  50. }}
  51. }
  52. }
  53. func decoderBuilder(parserConfig ParserConfig) any {
  54. decoder := schema.NewDecoder()
  55. decoder.IgnoreUnknownKeys(parserConfig.IgnoreUnknownKeys)
  56. if parserConfig.SetAliasTag != "" {
  57. decoder.SetAliasTag(parserConfig.SetAliasTag)
  58. }
  59. for _, v := range parserConfig.ParserType {
  60. decoder.RegisterConverter(reflect.ValueOf(v.CustomType).Interface(), v.Converter)
  61. }
  62. decoder.ZeroEmpty(parserConfig.ZeroEmpty)
  63. return decoder
  64. }
  65. func init() {
  66. decoderPoolMu.Lock()
  67. defer decoderPoolMu.Unlock()
  68. for _, tag := range tags {
  69. decoderPoolMap[tag] = &sync.Pool{New: func() any {
  70. return decoderBuilder(ParserConfig{
  71. IgnoreUnknownKeys: true,
  72. ZeroEmpty: true,
  73. })
  74. }}
  75. }
  76. }
  77. // parse data into the map or struct
  78. func parse(aliasTag string, out any, data map[string][]string, files ...map[string][]*multipart.FileHeader) error {
  79. ptrVal := reflect.ValueOf(out)
  80. // Get pointer value
  81. if ptrVal.Kind() == reflect.Ptr {
  82. ptrVal = ptrVal.Elem()
  83. }
  84. // Parse into the map
  85. if ptrVal.Kind() == reflect.Map && ptrVal.Type().Key().Kind() == reflect.String {
  86. return parseToMap(ptrVal, data)
  87. }
  88. // Parse into the struct
  89. return parseToStruct(aliasTag, out, data, files...)
  90. }
  91. // Parse data into the struct with gofiber/schema
  92. func parseToStruct(aliasTag string, out any, data map[string][]string, files ...map[string][]*multipart.FileHeader) error {
  93. // Get decoder from pool
  94. pool := getDecoderPool(aliasTag)
  95. schemaDecoder := pool.Get().(*schema.Decoder) //nolint:errcheck,forcetypeassert // not needed
  96. defer pool.Put(schemaDecoder)
  97. // Set alias tag
  98. schemaDecoder.SetAliasTag(aliasTag)
  99. if err := schemaDecoder.Decode(out, data, files...); err != nil {
  100. return fmt.Errorf("bind: %w", err)
  101. }
  102. return nil
  103. }
  104. // Parse data into the map
  105. // thanks to https://github.com/gin-gonic/gin/blob/master/binding/binding.go
  106. func parseToMap(target reflect.Value, data map[string][]string) error {
  107. if !target.IsValid() {
  108. return ErrInvalidDestinationValue
  109. }
  110. if target.Kind() == reflect.Interface && !target.IsNil() {
  111. target = target.Elem()
  112. }
  113. if target.Kind() != reflect.Map || target.Type().Key().Kind() != reflect.String {
  114. return nil // nothing to do for non-map destinations
  115. }
  116. if target.IsNil() {
  117. if !target.CanSet() {
  118. return ErrMapNilDestination
  119. }
  120. target.Set(reflect.MakeMap(target.Type()))
  121. }
  122. switch target.Type().Elem().Kind() {
  123. case reflect.Slice:
  124. newMap, ok := target.Interface().(map[string][]string)
  125. if !ok {
  126. return ErrMapNotConvertible
  127. }
  128. maps.Copy(newMap, data)
  129. case reflect.String:
  130. newMap, ok := target.Interface().(map[string]string)
  131. if !ok {
  132. return ErrMapNotConvertible
  133. }
  134. for k, v := range data {
  135. if len(v) == 0 {
  136. newMap[k] = ""
  137. continue
  138. }
  139. newMap[k] = v[len(v)-1]
  140. }
  141. default:
  142. // Interface element maps (e.g. map[string]any) are left untouched because
  143. // the binder cannot safely infer element conversions without mutating
  144. // caller-provided values. These destinations therefore see a successful
  145. // no-op parse.
  146. return nil // it's not necessary to check all types
  147. }
  148. return nil
  149. }
  150. func parseParamSquareBrackets(k string) (string, error) {
  151. bb := bytebufferpool.Get()
  152. defer bytebufferpool.Put(bb)
  153. kbytes := []byte(k)
  154. openBracketsCount := 0
  155. for i, b := range kbytes {
  156. if b == '[' {
  157. openBracketsCount++
  158. if i+1 < len(kbytes) && kbytes[i+1] != ']' {
  159. if err := bb.WriteByte('.'); err != nil {
  160. return "", err //nolint:wrapcheck // unnecessary to wrap it
  161. }
  162. }
  163. continue
  164. }
  165. if b == ']' {
  166. openBracketsCount--
  167. if openBracketsCount < 0 {
  168. return "", ErrUnmatchedBrackets
  169. }
  170. continue
  171. }
  172. if err := bb.WriteByte(b); err != nil {
  173. return "", err //nolint:wrapcheck // unnecessary to wrap it
  174. }
  175. }
  176. if openBracketsCount > 0 {
  177. return "", ErrUnmatchedBrackets
  178. }
  179. return bb.String(), nil
  180. }
  181. func isStringKeyMap(t reflect.Type) bool {
  182. return t.Kind() == reflect.Map && t.Key().Kind() == reflect.String
  183. }
  184. func isExported(f *reflect.StructField) bool {
  185. if f == nil {
  186. return false
  187. }
  188. return f.PkgPath == ""
  189. }
  190. func fieldName(f *reflect.StructField, aliasTag string) string {
  191. if f == nil {
  192. return ""
  193. }
  194. name := f.Tag.Get(aliasTag)
  195. if name == "" {
  196. name = f.Name
  197. } else if first, _, found := strings.Cut(name, ","); found {
  198. name = first
  199. }
  200. return utilsstrings.ToLower(name)
  201. }
// fieldInfo is the per-struct-type summary built by buildFieldInfo and
// cached per alias tag; equalFieldType consults it to decide whether a key
// targets a field of a given kind.
type fieldInfo struct {
	// names maps each exported top-level field's lowercased name (see
	// fieldName) to the Kind of its type after pointer unwrapping.
	names map[string]reflect.Kind
	// nestedKinds is the set of Kinds (after pointer unwrapping) of exported
	// fields found one level inside struct-typed top-level fields.
	nestedKinds map[reflect.Kind]struct{}
}
  206. func unwrapType(t reflect.Type) reflect.Type {
  207. for t.Kind() == reflect.Ptr {
  208. t = t.Elem()
  209. }
  210. return t
  211. }
  212. var (
  213. headerFieldCache sync.Map
  214. respHeaderFieldCache sync.Map
  215. cookieFieldCache sync.Map
  216. queryFieldCache sync.Map
  217. formFieldCache sync.Map
  218. uriFieldCache sync.Map
  219. )
  220. func getFieldCache(aliasTag string) *sync.Map {
  221. switch aliasTag {
  222. case "header":
  223. return &headerFieldCache
  224. case "respHeader":
  225. return &respHeaderFieldCache
  226. case "cookie":
  227. return &cookieFieldCache
  228. case "form":
  229. return &formFieldCache
  230. case "uri":
  231. return &uriFieldCache
  232. case "query":
  233. return &queryFieldCache
  234. }
  235. panic("unknown alias tag: " + aliasTag)
  236. }
  237. func buildFieldInfo(t reflect.Type, aliasTag string) fieldInfo {
  238. info := fieldInfo{
  239. names: make(map[string]reflect.Kind),
  240. nestedKinds: make(map[reflect.Kind]struct{}),
  241. }
  242. for i := 0; i < t.NumField(); i++ {
  243. f := t.Field(i)
  244. if !isExported(&f) {
  245. continue
  246. }
  247. fieldType := unwrapType(f.Type)
  248. info.names[fieldName(&f, aliasTag)] = fieldType.Kind()
  249. if fieldType.Kind() == reflect.Struct {
  250. for j := 0; j < fieldType.NumField(); j++ {
  251. sf := fieldType.Field(j)
  252. if !isExported(&sf) {
  253. continue
  254. }
  255. nestedType := unwrapType(sf.Type)
  256. info.nestedKinds[nestedType.Kind()] = struct{}{}
  257. }
  258. }
  259. }
  260. return info
  261. }
  262. func equalFieldType(out any, kind reflect.Kind, key, aliasTag string) bool {
  263. typ := reflect.TypeOf(out).Elem()
  264. key = utilsstrings.ToLower(key)
  265. if isStringKeyMap(typ) {
  266. return true
  267. }
  268. if typ.Kind() != reflect.Struct {
  269. return false
  270. }
  271. cache := getFieldCache(aliasTag)
  272. val, ok := cache.Load(typ)
  273. if !ok {
  274. info := buildFieldInfo(typ, aliasTag)
  275. val, _ = cache.LoadOrStore(typ, info)
  276. }
  277. info, ok := val.(fieldInfo)
  278. if !ok {
  279. return false
  280. }
  281. if k, ok := info.names[key]; ok && k == kind {
  282. return true
  283. }
  284. if _, ok := info.nestedKinds[kind]; ok {
  285. return true
  286. }
  287. return false
  288. }
  289. // FilterFlags returns the media type value by trimming any parameters from a Content-Type header.
  290. func FilterFlags(content string) string {
  291. if i := strings.IndexAny(content, " ;"); i >= 0 {
  292. return content[:i]
  293. }
  294. return content
  295. }
  296. func formatBindData[T, K any](aliasTag string, out any, data map[string][]T, key string, value K, enableSplitting, supportBracketNotation bool) error { //nolint:revive // it's okay
  297. var err error
  298. if supportBracketNotation && strings.IndexByte(key, '[') >= 0 {
  299. key, err = parseParamSquareBrackets(key)
  300. if err != nil {
  301. return err
  302. }
  303. }
  304. switch v := any(value).(type) {
  305. case string:
  306. dataMap, ok := any(data).(map[string][]string)
  307. if !ok {
  308. return fmt.Errorf("unsupported value type: %T", value)
  309. }
  310. assignBindData(aliasTag, out, dataMap, key, v, enableSplitting)
  311. case []string:
  312. dataMap, ok := any(data).(map[string][]string)
  313. if !ok {
  314. return fmt.Errorf("unsupported value type: %T", value)
  315. }
  316. for _, val := range v {
  317. assignBindData(aliasTag, out, dataMap, key, val, enableSplitting)
  318. }
  319. case []*multipart.FileHeader:
  320. for _, val := range v {
  321. valT, ok := any(val).(T)
  322. if !ok {
  323. return fmt.Errorf("unsupported value type: %T", value)
  324. }
  325. data[key] = append(data[key], valT)
  326. }
  327. default:
  328. return fmt.Errorf("unsupported value type: %T", value)
  329. }
  330. return err
  331. }
  332. func assignBindData(aliasTag string, out any, data map[string][]string, key, value string, enableSplitting bool) { //nolint:revive // it's okay
  333. if enableSplitting && strings.IndexByte(value, ',') >= 0 && equalFieldType(out, reflect.Slice, key, aliasTag) {
  334. for v := range strings.SplitSeq(value, ",") {
  335. data[key] = append(data[key], v)
  336. }
  337. } else {
  338. data[key] = append(data[key], value)
  339. }
  340. }