wmi.go 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501
  1. //go:build windows
  2. // +build windows
/*
Package wmi provides a WQL interface for WMI on Windows.

Example code to print names of running processes:

	type Win32_Process struct {
		Name string
	}

	func main() {
		var dst []Win32_Process
		q := wmi.CreateQuery(&dst, "")
		err := wmi.Query(q, &dst)
		if err != nil {
			log.Fatal(err)
		}
		for i, v := range dst {
			println(i, v.Name)
		}
	}
*/
  21. package wmi
  22. import (
  23. "bytes"
  24. "errors"
  25. "fmt"
  26. "log"
  27. "os"
  28. "reflect"
  29. "runtime"
  30. "strconv"
  31. "strings"
  32. "sync"
  33. "time"
  34. "github.com/gofiber/fiber/v2/internal/go-ole"
  35. "github.com/gofiber/fiber/v2/internal/go-ole/oleutil"
  36. )
  37. var l = log.New(os.Stdout, "", log.LstdFlags)
  38. var (
  39. ErrInvalidEntityType = errors.New("wmi: invalid entity type")
  40. // ErrNilCreateObject is the error returned if CreateObject returns nil even
  41. // if the error was nil.
  42. ErrNilCreateObject = errors.New("wmi: create object returned nil")
  43. lock sync.Mutex
  44. )
  45. // S_FALSE is returned by CoInitializeEx if it was already called on this thread.
  46. const S_FALSE = 0x00000001
  47. // QueryNamespace invokes Query with the given namespace on the local machine.
  48. func QueryNamespace(query string, dst interface{}, namespace string) error {
  49. return Query(query, dst, nil, namespace)
  50. }
  51. // Query runs the WQL query and appends the values to dst.
  52. //
  53. // dst must have type *[]S or *[]*S, for some struct type S. Fields selected in
  54. // the query must have the same name in dst. Supported types are all signed and
  55. // unsigned integers, time.Time, string, bool, or a pointer to one of those.
  56. // Array types are not supported.
  57. //
  58. // By default, the local machine and default namespace are used. These can be
  59. // changed using connectServerArgs. See
  60. // http://msdn.microsoft.com/en-us/library/aa393720.aspx for details.
  61. //
  62. // Query is a wrapper around DefaultClient.Query.
  63. func Query(query string, dst interface{}, connectServerArgs ...interface{}) error {
  64. if DefaultClient.SWbemServicesClient == nil {
  65. return DefaultClient.Query(query, dst, connectServerArgs...)
  66. }
  67. return DefaultClient.SWbemServicesClient.Query(query, dst, connectServerArgs...)
  68. }
  69. // A Client is an WMI query client.
  70. //
  71. // Its zero value (DefaultClient) is a usable client.
  72. type Client struct {
  73. // NonePtrZero specifies if nil values for fields which aren't pointers
  74. // should be returned as the field types zero value.
  75. //
  76. // Setting this to true allows stucts without pointer fields to be used
  77. // without the risk failure should a nil value returned from WMI.
  78. NonePtrZero bool
  79. // PtrNil specifies if nil values for pointer fields should be returned
  80. // as nil.
  81. //
  82. // Setting this to true will set pointer fields to nil where WMI
  83. // returned nil, otherwise the types zero value will be returned.
  84. PtrNil bool
  85. // AllowMissingFields specifies that struct fields not present in the
  86. // query result should not result in an error.
  87. //
  88. // Setting this to true allows custom queries to be used with full
  89. // struct definitions instead of having to define multiple structs.
  90. AllowMissingFields bool
  91. // SWbemServiceClient is an optional SWbemServices object that can be
  92. // initialized and then reused across multiple queries. If it is null
  93. // then the method will initialize a new temporary client each time.
  94. SWbemServicesClient *SWbemServices
  95. }
  96. // DefaultClient is the default Client and is used by Query, QueryNamespace
  97. var DefaultClient = &Client{}
  98. // Query runs the WQL query and appends the values to dst.
  99. //
  100. // dst must have type *[]S or *[]*S, for some struct type S. Fields selected in
  101. // the query must have the same name in dst. Supported types are all signed and
  102. // unsigned integers, time.Time, string, bool, or a pointer to one of those.
  103. // Array types are not supported.
  104. //
  105. // By default, the local machine and default namespace are used. These can be
  106. // changed using connectServerArgs. See
  107. // http://msdn.microsoft.com/en-us/library/aa393720.aspx for details.
  108. func (c *Client) Query(query string, dst interface{}, connectServerArgs ...interface{}) error {
  109. dv := reflect.ValueOf(dst)
  110. if dv.Kind() != reflect.Ptr || dv.IsNil() {
  111. return ErrInvalidEntityType
  112. }
  113. dv = dv.Elem()
  114. mat, elemType := checkMultiArg(dv)
  115. if mat == multiArgTypeInvalid {
  116. return ErrInvalidEntityType
  117. }
  118. lock.Lock()
  119. defer lock.Unlock()
  120. runtime.LockOSThread()
  121. defer runtime.UnlockOSThread()
  122. err := ole.CoInitializeEx(0, ole.COINIT_MULTITHREADED)
  123. if err != nil {
  124. oleCode := err.(*ole.OleError).Code()
  125. if oleCode != ole.S_OK && oleCode != S_FALSE {
  126. return err
  127. }
  128. }
  129. defer ole.CoUninitialize()
  130. unknown, err := oleutil.CreateObject("WbemScripting.SWbemLocator")
  131. if err != nil {
  132. return err
  133. } else if unknown == nil {
  134. return ErrNilCreateObject
  135. }
  136. defer unknown.Release()
  137. wmi, err := unknown.QueryInterface(ole.IID_IDispatch)
  138. if err != nil {
  139. return err
  140. }
  141. defer wmi.Release()
  142. // service is a SWbemServices
  143. serviceRaw, err := oleutil.CallMethod(wmi, "ConnectServer", connectServerArgs...)
  144. if err != nil {
  145. return err
  146. }
  147. service := serviceRaw.ToIDispatch()
  148. defer serviceRaw.Clear()
  149. // result is a SWBemObjectSet
  150. resultRaw, err := oleutil.CallMethod(service, "ExecQuery", query)
  151. if err != nil {
  152. return err
  153. }
  154. result := resultRaw.ToIDispatch()
  155. defer resultRaw.Clear()
  156. count, err := oleInt64(result, "Count")
  157. if err != nil {
  158. return err
  159. }
  160. enumProperty, err := result.GetProperty("_NewEnum")
  161. if err != nil {
  162. return err
  163. }
  164. defer enumProperty.Clear()
  165. enum, err := enumProperty.ToIUnknown().IEnumVARIANT(ole.IID_IEnumVariant)
  166. if err != nil {
  167. return err
  168. }
  169. if enum == nil {
  170. return fmt.Errorf("can't get IEnumVARIANT, enum is nil")
  171. }
  172. defer enum.Release()
  173. // Initialize a slice with Count capacity
  174. dv.Set(reflect.MakeSlice(dv.Type(), 0, int(count)))
  175. var errFieldMismatch error
  176. for itemRaw, length, err := enum.Next(1); length > 0; itemRaw, length, err = enum.Next(1) {
  177. if err != nil {
  178. return err
  179. }
  180. err := func() error {
  181. // item is a SWbemObject, but really a Win32_Process
  182. item := itemRaw.ToIDispatch()
  183. defer item.Release()
  184. ev := reflect.New(elemType)
  185. if err = c.loadEntity(ev.Interface(), item); err != nil {
  186. if _, ok := err.(*ErrFieldMismatch); ok {
  187. // We continue loading entities even in the face of field mismatch errors.
  188. // If we encounter any other error, that other error is returned. Otherwise,
  189. // an ErrFieldMismatch is returned.
  190. errFieldMismatch = err
  191. } else {
  192. return err
  193. }
  194. }
  195. if mat != multiArgTypeStructPtr {
  196. ev = ev.Elem()
  197. }
  198. dv.Set(reflect.Append(dv, ev))
  199. return nil
  200. }()
  201. if err != nil {
  202. return err
  203. }
  204. }
  205. return errFieldMismatch
  206. }
  207. // ErrFieldMismatch is returned when a field is to be loaded into a different
  208. // type than the one it was stored from, or when a field is missing or
  209. // unexported in the destination struct.
  210. // StructType is the type of the struct pointed to by the destination argument.
  211. type ErrFieldMismatch struct {
  212. StructType reflect.Type
  213. FieldName string
  214. Reason string
  215. }
  216. func (e *ErrFieldMismatch) Error() string {
  217. return fmt.Sprintf("wmi: cannot load field %q into a %q: %s",
  218. e.FieldName, e.StructType, e.Reason)
  219. }
  220. var timeType = reflect.TypeOf(time.Time{})
  221. // loadEntity loads a SWbemObject into a struct pointer.
  222. func (c *Client) loadEntity(dst interface{}, src *ole.IDispatch) (errFieldMismatch error) {
  223. v := reflect.ValueOf(dst).Elem()
  224. for i := 0; i < v.NumField(); i++ {
  225. f := v.Field(i)
  226. of := f
  227. isPtr := f.Kind() == reflect.Ptr
  228. if isPtr {
  229. ptr := reflect.New(f.Type().Elem())
  230. f.Set(ptr)
  231. f = f.Elem()
  232. }
  233. n := v.Type().Field(i).Name
  234. if !f.CanSet() {
  235. return &ErrFieldMismatch{
  236. StructType: of.Type(),
  237. FieldName: n,
  238. Reason: "CanSet() is false",
  239. }
  240. }
  241. prop, err := oleutil.GetProperty(src, n)
  242. if err != nil {
  243. if !c.AllowMissingFields {
  244. errFieldMismatch = &ErrFieldMismatch{
  245. StructType: of.Type(),
  246. FieldName: n,
  247. Reason: "no such struct field",
  248. }
  249. }
  250. continue
  251. }
  252. defer prop.Clear()
  253. if prop.VT == 0x1 { // VT_NULL
  254. continue
  255. }
  256. switch val := prop.Value().(type) {
  257. case int8, int16, int32, int64, int:
  258. v := reflect.ValueOf(val).Int()
  259. switch f.Kind() {
  260. case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
  261. f.SetInt(v)
  262. case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
  263. f.SetUint(uint64(v))
  264. default:
  265. return &ErrFieldMismatch{
  266. StructType: of.Type(),
  267. FieldName: n,
  268. Reason: "not an integer class",
  269. }
  270. }
  271. case uint8, uint16, uint32, uint64:
  272. v := reflect.ValueOf(val).Uint()
  273. switch f.Kind() {
  274. case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
  275. f.SetInt(int64(v))
  276. case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
  277. f.SetUint(v)
  278. default:
  279. return &ErrFieldMismatch{
  280. StructType: of.Type(),
  281. FieldName: n,
  282. Reason: "not an integer class",
  283. }
  284. }
  285. case string:
  286. switch f.Kind() {
  287. case reflect.String:
  288. f.SetString(val)
  289. case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
  290. iv, err := strconv.ParseInt(val, 10, 64)
  291. if err != nil {
  292. return err
  293. }
  294. f.SetInt(iv)
  295. case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
  296. uv, err := strconv.ParseUint(val, 10, 64)
  297. if err != nil {
  298. return err
  299. }
  300. f.SetUint(uv)
  301. case reflect.Struct:
  302. switch f.Type() {
  303. case timeType:
  304. if len(val) == 25 {
  305. mins, err := strconv.Atoi(val[22:])
  306. if err != nil {
  307. return err
  308. }
  309. val = val[:22] + fmt.Sprintf("%02d%02d", mins/60, mins%60)
  310. }
  311. t, err := time.Parse("20060102150405.000000-0700", val)
  312. if err != nil {
  313. return err
  314. }
  315. f.Set(reflect.ValueOf(t))
  316. }
  317. }
  318. case bool:
  319. switch f.Kind() {
  320. case reflect.Bool:
  321. f.SetBool(val)
  322. default:
  323. return &ErrFieldMismatch{
  324. StructType: of.Type(),
  325. FieldName: n,
  326. Reason: "not a bool",
  327. }
  328. }
  329. case float32:
  330. switch f.Kind() {
  331. case reflect.Float32:
  332. f.SetFloat(float64(val))
  333. default:
  334. return &ErrFieldMismatch{
  335. StructType: of.Type(),
  336. FieldName: n,
  337. Reason: "not a Float32",
  338. }
  339. }
  340. default:
  341. if f.Kind() == reflect.Slice {
  342. switch f.Type().Elem().Kind() {
  343. case reflect.String:
  344. safeArray := prop.ToArray()
  345. if safeArray != nil {
  346. arr := safeArray.ToValueArray()
  347. fArr := reflect.MakeSlice(f.Type(), len(arr), len(arr))
  348. for i, v := range arr {
  349. s := fArr.Index(i)
  350. s.SetString(v.(string))
  351. }
  352. f.Set(fArr)
  353. }
  354. case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint:
  355. safeArray := prop.ToArray()
  356. if safeArray != nil {
  357. arr := safeArray.ToValueArray()
  358. fArr := reflect.MakeSlice(f.Type(), len(arr), len(arr))
  359. for i, v := range arr {
  360. s := fArr.Index(i)
  361. s.SetUint(reflect.ValueOf(v).Uint())
  362. }
  363. f.Set(fArr)
  364. }
  365. case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int:
  366. safeArray := prop.ToArray()
  367. if safeArray != nil {
  368. arr := safeArray.ToValueArray()
  369. fArr := reflect.MakeSlice(f.Type(), len(arr), len(arr))
  370. for i, v := range arr {
  371. s := fArr.Index(i)
  372. s.SetInt(reflect.ValueOf(v).Int())
  373. }
  374. f.Set(fArr)
  375. }
  376. default:
  377. return &ErrFieldMismatch{
  378. StructType: of.Type(),
  379. FieldName: n,
  380. Reason: fmt.Sprintf("unsupported slice type (%T)", val),
  381. }
  382. }
  383. } else {
  384. typeof := reflect.TypeOf(val)
  385. if typeof == nil && (isPtr || c.NonePtrZero) {
  386. if (isPtr && c.PtrNil) || (!isPtr && c.NonePtrZero) {
  387. of.Set(reflect.Zero(of.Type()))
  388. }
  389. break
  390. }
  391. return &ErrFieldMismatch{
  392. StructType: of.Type(),
  393. FieldName: n,
  394. Reason: fmt.Sprintf("unsupported type (%T)", val),
  395. }
  396. }
  397. }
  398. }
  399. return errFieldMismatch
  400. }
  401. type multiArgType int
  402. const (
  403. multiArgTypeInvalid multiArgType = iota
  404. multiArgTypeStruct
  405. multiArgTypeStructPtr
  406. )
  407. // checkMultiArg checks that v has type []S, []*S for some struct type S.
  408. //
  409. // It returns what category the slice's elements are, and the reflect.Type
  410. // that represents S.
  411. func checkMultiArg(v reflect.Value) (m multiArgType, elemType reflect.Type) {
  412. if v.Kind() != reflect.Slice {
  413. return multiArgTypeInvalid, nil
  414. }
  415. elemType = v.Type().Elem()
  416. switch elemType.Kind() {
  417. case reflect.Struct:
  418. return multiArgTypeStruct, elemType
  419. case reflect.Ptr:
  420. elemType = elemType.Elem()
  421. if elemType.Kind() == reflect.Struct {
  422. return multiArgTypeStructPtr, elemType
  423. }
  424. }
  425. return multiArgTypeInvalid, nil
  426. }
  427. func oleInt64(item *ole.IDispatch, prop string) (int64, error) {
  428. v, err := oleutil.GetProperty(item, prop)
  429. if err != nil {
  430. return 0, err
  431. }
  432. defer v.Clear()
  433. i := int64(v.Val)
  434. return i, nil
  435. }
  436. // CreateQuery returns a WQL query string that queries all columns of src. where
  437. // is an optional string that is appended to the query, to be used with WHERE
  438. // clauses. In such a case, the "WHERE" string should appear at the beginning.
  439. func CreateQuery(src interface{}, where string) string {
  440. var b bytes.Buffer
  441. b.WriteString("SELECT ")
  442. s := reflect.Indirect(reflect.ValueOf(src))
  443. t := s.Type()
  444. if s.Kind() == reflect.Slice {
  445. t = t.Elem()
  446. }
  447. if t.Kind() != reflect.Struct {
  448. return ""
  449. }
  450. var fields []string
  451. for i := 0; i < t.NumField(); i++ {
  452. fields = append(fields, t.Field(i).Name)
  453. }
  454. b.WriteString(strings.Join(fields, ", "))
  455. b.WriteString(" FROM ")
  456. b.WriteString(t.Name())
  457. b.WriteString(" " + where)
  458. return b.String()
  459. }