filterfs.go 1.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293
  1. package filterfs
  2. import "io/fs"
  3. import "fmt"
  4. // Filter returns true to include a file name and false to exclude it.
  5. type Filter func(name string) bool
  6. type FS struct {
  7. underlying fs.FS
  8. filter Filter
  9. }
  10. func Apply(sys fs.FS, filter Filter) FS {
  11. return FS{sys, filter}
  12. }
  13. type File struct {
  14. underlying fs.ReadDirFile
  15. filter Filter
  16. }
  17. func (f File) Stat() (fs.FileInfo, error) {
  18. return f.underlying.Stat()
  19. }
  20. func (f File) Read(buf []byte) (int, error) {
  21. return f.underlying.Read(buf)
  22. }
  23. func (f File) Close() error {
  24. return f.underlying.Close()
  25. }
  26. func filterEntries(filter Filter, entries []fs.DirEntry) []fs.DirEntry {
  27. result := []fs.DirEntry{}
  28. for _, entry := range entries {
  29. if filter(entry.Name()) {
  30. result = append(result, entry)
  31. }
  32. }
  33. return result
  34. }
  35. func (f File) ReadDir(n int) ([]fs.DirEntry, error) {
  36. entries, err := f.underlying.ReadDir(n)
  37. if err != nil {
  38. return entries, err
  39. }
  40. return filterEntries(f.filter, entries), nil
  41. }
  42. func (sys FS) Open(name string) (fs.File, error) {
  43. file, err := sys.underlying.Open(name)
  44. if err != nil {
  45. return nil, err
  46. }
  47. if dir, ok := file.(fs.ReadDirFile); ok {
  48. return File{dir, sys.filter}, nil
  49. } else {
  50. return file, nil
  51. }
  52. }
  53. func (sys FS) Glob(pattern string) ([]string, error) {
  54. entries, err := fs.Glob(sys.underlying, pattern)
  55. if err != nil {
  56. return entries, err
  57. }
  58. result := []string{}
  59. for _, entry := range entries {
  60. if sys.filter(entry) {
  61. result = append(result, entry)
  62. }
  63. }
  64. fmt.Printf("%v\n", entries)
  65. fmt.Printf("%v\n", result)
  66. return result, nil
  67. }
  68. /*
  69. func (sys FS) ReadDir(fname string) ([]fs.DirEntry, error) {
  70. entries, err := fs.ReadDir(sys.underlying, fname)
  71. if err != nil {
  72. return entries, err
  73. }
  74. return filterEntries(sys.filter, entries), nil
  75. }
  76. */
  77. var _ fs.FS = &FS{}
  78. var _ fs.GlobFS = &FS{}
  79. var _ fs.ReadDirFS = &FS{}