From 2b7f63e76b48185cf054e2937dbe3544d33ce89c Mon Sep 17 00:00:00 2001 From: Stevie Robinson Date: Sat, 4 May 2024 12:53:27 +0200 Subject: [PATCH] allow multiple protocols in custom filter --- .../src/Store/Actions/blocklistActions.js | 2 +- .../Blocklisting/BlocklistRepository.cs | 27 ++++++++++++------- .../Blocklisting/BlocklistService.cs | 6 ++--- .../Blocklist/BlocklistController.cs | 4 +-- 4 files changed, 24 insertions(+), 15 deletions(-) diff --git a/frontend/src/Store/Actions/blocklistActions.js b/frontend/src/Store/Actions/blocklistActions.js index f422bc095..6303ad2d1 100644 --- a/frontend/src/Store/Actions/blocklistActions.js +++ b/frontend/src/Store/Actions/blocklistActions.js @@ -97,7 +97,7 @@ export const defaultState = { valueType: filterBuilderValueTypes.SERIES }, { - name: 'protocol', + name: 'protocols', label: () => translate('Protocol'), type: filterBuilderTypes.EQUAL, valueType: filterBuilderValueTypes.PROTOCOL diff --git a/src/NzbDrone.Core/Blocklisting/BlocklistRepository.cs b/src/NzbDrone.Core/Blocklisting/BlocklistRepository.cs index ee8328661..11c8243b1 100644 --- a/src/NzbDrone.Core/Blocklisting/BlocklistRepository.cs +++ b/src/NzbDrone.Core/Blocklisting/BlocklistRepository.cs @@ -12,7 +12,7 @@ namespace NzbDrone.Core.Blocklisting List BlocklistedByTorrentInfoHash(int seriesId, string torrentInfoHash); List BlocklistedBySeries(int seriesId); void DeleteForSeriesIds(List seriesIds); - PagingSpec GetPaged(PagingSpec pagingSpec, DownloadProtocol? protocol); + PagingSpec GetPaged(PagingSpec pagingSpec, DownloadProtocol[] protocols); } public class BlocklistRepository : BasicRepository, IBlocklistRepository @@ -42,24 +42,24 @@ namespace NzbDrone.Core.Blocklisting Delete(x => seriesIds.Contains(x.SeriesId)); } - public PagingSpec GetPaged(PagingSpec pagingSpec, DownloadProtocol? protocol) + public PagingSpec GetPaged(PagingSpec pagingSpec, DownloadProtocol[] protocols) { - pagingSpec.Records = GetPagedRecords(PagedBuilder(protocol), pagingSpec, PagedQuery); + pagingSpec.Records = GetPagedRecords(PagedBuilder(protocols), pagingSpec, PagedQuery); var countTemplate = $"SELECT COUNT(*) FROM (SELECT /**select**/ FROM \"{TableMapping.Mapper.TableNameMapping(typeof(Blocklist))}\" /**join**/ /**innerjoin**/ /**leftjoin**/ /**where**/ /**groupby**/ /**having**/) AS \"Inner\""; - pagingSpec.TotalRecords = GetPagedRecordCount(PagedBuilder(protocol).Select(typeof(Blocklist)), pagingSpec, countTemplate); + pagingSpec.TotalRecords = GetPagedRecordCount(PagedBuilder(protocols).Select(typeof(Blocklist)), pagingSpec, countTemplate); return pagingSpec; } - private SqlBuilder PagedBuilder(DownloadProtocol? protocol) + private SqlBuilder PagedBuilder(DownloadProtocol[] protocols) { var builder = Builder() .Join((b, m) => b.SeriesId == m.Id); - if (protocol != null) + if (protocols is { Length: > 0 }) { - builder.Where($"({BuildProtocolWhereClause(protocol)})"); + builder.Where($"({BuildProtocolWhereClause(protocols)})"); } return builder; @@ -72,7 +72,16 @@ namespace NzbDrone.Core.Blocklisting return blocklist; }); - private string BuildProtocolWhereClause(DownloadProtocol? protocol) => - $"\"{TableMapping.Mapper.TableNameMapping(typeof(Blocklist))}\".\"Protocol\" = {(int)protocol}"; + private string BuildProtocolWhereClause(DownloadProtocol[] protocols) + { + var clauses = new List(); + + foreach (var protocol in protocols) + { + clauses.Add($"\"{TableMapping.Mapper.TableNameMapping(typeof(Blocklist))}\".\"Protocol\" = {(int)protocol}"); + } + + return $"({string.Join(" OR ", clauses)})"; + } } } diff --git a/src/NzbDrone.Core/Blocklisting/BlocklistService.cs b/src/NzbDrone.Core/Blocklisting/BlocklistService.cs index 74f20c0fa..137a5dc32 100644 --- a/src/NzbDrone.Core/Blocklisting/BlocklistService.cs +++ b/src/NzbDrone.Core/Blocklisting/BlocklistService.cs @@ -16,7 +16,7 @@ namespace NzbDrone.Core.Blocklisting { bool Blocklisted(int seriesId, ReleaseInfo release); bool BlocklistedTorrentHash(int seriesId, string hash); - PagingSpec Paged(PagingSpec pagingSpec, DownloadProtocol? protocol); + PagingSpec Paged(PagingSpec pagingSpec, DownloadProtocol[] protocols); void Block(RemoteEpisode remoteEpisode, string message); void Delete(int id); void Delete(List ids); @@ -66,9 +66,9 @@ namespace NzbDrone.Core.Blocklisting b.TorrentInfoHash.Equals(hash, StringComparison.InvariantCultureIgnoreCase)); } - public PagingSpec Paged(PagingSpec pagingSpec, DownloadProtocol? protocol) + public PagingSpec Paged(PagingSpec pagingSpec, DownloadProtocol[] protocols) { - return _blocklistRepository.GetPaged(pagingSpec, protocol); + return _blocklistRepository.GetPaged(pagingSpec, protocols); } public void Block(RemoteEpisode remoteEpisode, string message) diff --git a/src/Sonarr.Api.V3/Blocklist/BlocklistController.cs b/src/Sonarr.Api.V3/Blocklist/BlocklistController.cs index ef885ce95..5762c272b 100644 --- a/src/Sonarr.Api.V3/Blocklist/BlocklistController.cs +++ b/src/Sonarr.Api.V3/Blocklist/BlocklistController.cs @@ -25,7 +25,7 @@ namespace Sonarr.Api.V3.Blocklist [HttpGet] [Produces("application/json")] - public PagingResource GetBlocklist([FromQuery] PagingRequestResource paging, [FromQuery] int[] seriesIds = null, [FromQuery] DownloadProtocol? protocol = null) + public PagingResource GetBlocklist([FromQuery] PagingRequestResource paging, [FromQuery] int[] seriesIds = null, [FromQuery] DownloadProtocol[] protocols = null) { var pagingResource = new PagingResource(paging); var pagingSpec = pagingResource.MapToPagingSpec("date", SortDirection.Descending); @@ -35,7 +35,7 @@ namespace Sonarr.Api.V3.Blocklist pagingSpec.FilterExpressions.Add(b => seriesIds.Contains(b.SeriesId)); } - return pagingSpec.ApplyToPage(b => _blocklistService.Paged(pagingSpec, protocol), b => BlocklistResourceMapper.MapToResource(b, _formatCalculator)); + return pagingSpec.ApplyToPage(b => _blocklistService.Paged(pagingSpec, protocols), b => BlocklistResourceMapper.MapToResource(b, _formatCalculator)); } [RestDeleteById]