497 lines
19 KiB
C#
497 lines
19 KiB
C#
// Copyright (c) Microsoft Open Technologies, Inc. All rights reserved. See License.md in the project root for license information.
|
|
|
|
using System;
|
|
using System.Collections.Generic;
|
|
using System.Diagnostics;
|
|
using System.Diagnostics.CodeAnalysis;
|
|
using System.Globalization;
|
|
using System.Linq;
|
|
using System.Text;
|
|
using System.Threading.Tasks;
|
|
using Microsoft.AspNet.SignalR.Configuration;
|
|
using Microsoft.AspNet.SignalR.Hosting;
|
|
using Microsoft.AspNet.SignalR.Infrastructure;
|
|
using Microsoft.AspNet.SignalR.Json;
|
|
using Microsoft.AspNet.SignalR.Messaging;
|
|
using Microsoft.AspNet.SignalR.Tracing;
|
|
using Microsoft.AspNet.SignalR.Transports;
|
|
|
|
namespace Microsoft.AspNet.SignalR
|
|
{
|
|
/// <summary>
|
|
/// Represents a connection between client and server.
|
|
/// </summary>
|
|
public abstract class PersistentConnection
|
|
{
|
|
private const string WebSocketsTransportName = "webSockets";
|
|
private static readonly char[] SplitChars = new[] { ':' };
|
|
|
|
private IConfigurationManager _configurationManager;
|
|
private ITransportManager _transportManager;
|
|
private bool _initialized;
|
|
private IServerCommandHandler _serverMessageHandler;
|
|
|
|
public virtual void Initialize(IDependencyResolver resolver, HostContext context)
|
|
{
|
|
if (resolver == null)
|
|
{
|
|
throw new ArgumentNullException("resolver");
|
|
}
|
|
|
|
if (context == null)
|
|
{
|
|
throw new ArgumentNullException("context");
|
|
}
|
|
|
|
if (_initialized)
|
|
{
|
|
return;
|
|
}
|
|
|
|
MessageBus = resolver.Resolve<IMessageBus>();
|
|
JsonSerializer = resolver.Resolve<IJsonSerializer>();
|
|
TraceManager = resolver.Resolve<ITraceManager>();
|
|
Counters = resolver.Resolve<IPerformanceCounterManager>();
|
|
AckHandler = resolver.Resolve<IAckHandler>();
|
|
ProtectedData = resolver.Resolve<IProtectedData>();
|
|
|
|
_configurationManager = resolver.Resolve<IConfigurationManager>();
|
|
_transportManager = resolver.Resolve<ITransportManager>();
|
|
_serverMessageHandler = resolver.Resolve<IServerCommandHandler>();
|
|
|
|
_initialized = true;
|
|
}
|
|
|
|
public bool Authorize(IRequest request)
|
|
{
|
|
return AuthorizeRequest(request);
|
|
}
|
|
|
|
protected virtual TraceSource Trace
|
|
{
|
|
get
|
|
{
|
|
return TraceManager["SignalR.PersistentConnection"];
|
|
}
|
|
}
|
|
|
|
protected IProtectedData ProtectedData { get; private set; }
|
|
|
|
protected IMessageBus MessageBus { get; private set; }
|
|
|
|
protected IJsonSerializer JsonSerializer { get; private set; }
|
|
|
|
protected IAckHandler AckHandler { get; private set; }
|
|
|
|
protected ITraceManager TraceManager { get; private set; }
|
|
|
|
protected IPerformanceCounterManager Counters { get; private set; }
|
|
|
|
protected ITransport Transport { get; private set; }
|
|
|
|
/// <summary>
|
|
/// Gets the <see cref="IConnection"/> for the <see cref="PersistentConnection"/>.
|
|
/// </summary>
|
|
public IConnection Connection
|
|
{
|
|
get;
|
|
private set;
|
|
}
|
|
|
|
/// <summary>
|
|
/// Gets the <see cref="IConnectionGroupManager"/> for the <see cref="PersistentConnection"/>.
|
|
/// </summary>
|
|
public IConnectionGroupManager Groups
|
|
{
|
|
get;
|
|
private set;
|
|
}
|
|
|
|
private string DefaultSignal
|
|
{
|
|
get
|
|
{
|
|
return PrefixHelper.GetPersistentConnectionName(DefaultSignalRaw);
|
|
}
|
|
}
|
|
|
|
private string DefaultSignalRaw
|
|
{
|
|
get
|
|
{
|
|
return GetType().FullName;
|
|
}
|
|
}
|
|
|
|
internal virtual string GroupPrefix
|
|
{
|
|
get
|
|
{
|
|
return PrefixHelper.PersistentConnectionGroupPrefix;
|
|
}
|
|
}
|
|
|
|
/// <summary>
|
|
/// Handles all requests for <see cref="PersistentConnection"/>s.
|
|
/// </summary>
|
|
/// <param name="context">The <see cref="HostContext"/> for the current request.</param>
|
|
/// <returns>A <see cref="Task"/> that completes when the <see cref="PersistentConnection"/> pipeline is complete.</returns>
|
|
/// <exception cref="T:System.InvalidOperationException">
|
|
/// Thrown if connection wasn't initialized.
|
|
/// Thrown if the transport wasn't specified.
|
|
/// Thrown if the connection id wasn't specified.
|
|
/// </exception>
|
|
public virtual Task ProcessRequest(HostContext context)
|
|
{
|
|
if (context == null)
|
|
{
|
|
throw new ArgumentNullException("context");
|
|
}
|
|
|
|
if (!_initialized)
|
|
{
|
|
throw new InvalidOperationException(String.Format(CultureInfo.CurrentCulture, Resources.Error_ConnectionNotInitialized));
|
|
}
|
|
|
|
if (IsNegotiationRequest(context.Request))
|
|
{
|
|
return ProcessNegotiationRequest(context);
|
|
}
|
|
else if (IsPingRequest(context.Request))
|
|
{
|
|
return ProcessPingRequest(context);
|
|
}
|
|
|
|
Transport = GetTransport(context);
|
|
|
|
if (Transport == null)
|
|
{
|
|
throw new InvalidOperationException(String.Format(CultureInfo.CurrentCulture, Resources.Error_ProtocolErrorUnknownTransport));
|
|
}
|
|
|
|
string connectionToken = context.Request.QueryString["connectionToken"];
|
|
|
|
// If there's no connection id then this is a bad request
|
|
if (String.IsNullOrEmpty(connectionToken))
|
|
{
|
|
throw new InvalidOperationException(String.Format(CultureInfo.CurrentCulture, Resources.Error_ProtocolErrorMissingConnectionToken));
|
|
}
|
|
|
|
string connectionId = GetConnectionId(context, connectionToken);
|
|
|
|
// Set the transport's connection id to the unprotected one
|
|
Transport.ConnectionId = connectionId;
|
|
|
|
IList<string> signals = GetSignals(connectionId);
|
|
IList<string> groups = AppendGroupPrefixes(context, connectionId);
|
|
|
|
Connection connection = CreateConnection(connectionId, signals, groups);
|
|
|
|
Connection = connection;
|
|
string groupName = PrefixHelper.GetPersistentConnectionGroupName(DefaultSignalRaw);
|
|
Groups = new GroupManager(connection, groupName);
|
|
|
|
Transport.TransportConnected = () =>
|
|
{
|
|
var command = new ServerCommand
|
|
{
|
|
ServerCommandType = ServerCommandType.RemoveConnection,
|
|
Value = connectionId
|
|
};
|
|
|
|
return _serverMessageHandler.SendCommand(command);
|
|
};
|
|
|
|
Transport.Connected = () =>
|
|
{
|
|
return TaskAsyncHelper.FromMethod(() => OnConnected(context.Request, connectionId).OrEmpty());
|
|
};
|
|
|
|
Transport.Reconnected = () =>
|
|
{
|
|
return TaskAsyncHelper.FromMethod(() => OnReconnected(context.Request, connectionId).OrEmpty());
|
|
};
|
|
|
|
Transport.Received = data =>
|
|
{
|
|
Counters.ConnectionMessagesSentTotal.Increment();
|
|
Counters.ConnectionMessagesSentPerSec.Increment();
|
|
return TaskAsyncHelper.FromMethod(() => OnReceived(context.Request, connectionId, data).OrEmpty());
|
|
};
|
|
|
|
Transport.Disconnected = () =>
|
|
{
|
|
return TaskAsyncHelper.FromMethod(() => OnDisconnected(context.Request, connectionId).OrEmpty());
|
|
};
|
|
|
|
return Transport.ProcessRequest(connection).OrEmpty().Catch(Counters.ErrorsAllTotal, Counters.ErrorsAllPerSec);
|
|
}
|
|
|
|
[SuppressMessage("Microsoft.Design", "CA1031:DoNotCatchGeneralExceptionTypes", Justification = "We want to catch any exception when unprotecting data.")]
|
|
internal string GetConnectionId(HostContext context, string connectionToken)
|
|
{
|
|
string unprotectedConnectionToken = null;
|
|
|
|
try
|
|
{
|
|
unprotectedConnectionToken = ProtectedData.Unprotect(connectionToken, Purposes.ConnectionToken);
|
|
}
|
|
catch (Exception ex)
|
|
{
|
|
Trace.TraceInformation("Failed to process connectionToken {0}: {1}", connectionToken, ex);
|
|
}
|
|
|
|
if (String.IsNullOrEmpty(unprotectedConnectionToken))
|
|
{
|
|
throw new InvalidOperationException(String.Format(CultureInfo.CurrentCulture, Resources.Error_ConnectionIdIncorrectFormat));
|
|
}
|
|
|
|
var tokens = unprotectedConnectionToken.Split(SplitChars, 2);
|
|
|
|
string connectionId = tokens[0];
|
|
string tokenUserName = tokens.Length > 1 ? tokens[1] : String.Empty;
|
|
string userName = GetUserIdentity(context);
|
|
|
|
if (!String.Equals(tokenUserName, userName, StringComparison.OrdinalIgnoreCase))
|
|
{
|
|
throw new InvalidOperationException(String.Format(CultureInfo.CurrentCulture, Resources.Error_UnrecognizedUserIdentity));
|
|
}
|
|
|
|
return connectionId;
|
|
}
|
|
|
|
[SuppressMessage("Microsoft.Design", "CA1031:DoNotCatchGeneralExceptionTypes", Justification = "We want to prevent any failures in unprotecting")]
|
|
internal IList<string> VerifyGroups(HostContext context, string connectionId)
|
|
{
|
|
string groupsToken = context.Request.QueryString["groupsToken"];
|
|
|
|
if (String.IsNullOrEmpty(groupsToken))
|
|
{
|
|
Trace.TraceInformation("The groups token is missing");
|
|
|
|
return ListHelper<string>.Empty;
|
|
}
|
|
|
|
string unprotectedGroupsToken = null;
|
|
|
|
try
|
|
{
|
|
unprotectedGroupsToken = ProtectedData.Unprotect(groupsToken, Purposes.Groups);
|
|
}
|
|
catch (Exception ex)
|
|
{
|
|
Trace.TraceInformation("Failed to process groupsToken {0}: {1}", groupsToken, ex);
|
|
}
|
|
|
|
if (String.IsNullOrEmpty(unprotectedGroupsToken))
|
|
{
|
|
return ListHelper<string>.Empty;
|
|
}
|
|
|
|
var tokens = unprotectedGroupsToken.Split(SplitChars, 2);
|
|
|
|
string groupConnectionId = tokens[0];
|
|
string groupsValue = tokens.Length > 1 ? tokens[1] : String.Empty;
|
|
|
|
if (!String.Equals(groupConnectionId, connectionId, StringComparison.OrdinalIgnoreCase))
|
|
{
|
|
return ListHelper<string>.Empty;
|
|
}
|
|
|
|
return JsonSerializer.Parse<string[]>(groupsValue);
|
|
}
|
|
|
|
private IList<string> AppendGroupPrefixes(HostContext context, string connectionId)
|
|
{
|
|
return (from g in OnRejoiningGroups(context.Request, VerifyGroups(context, connectionId), connectionId)
|
|
select GroupPrefix + g).ToList();
|
|
}
|
|
|
|
private Connection CreateConnection(string connectionId, IList<string> signals, IList<string> groups)
|
|
{
|
|
return new Connection(MessageBus,
|
|
JsonSerializer,
|
|
DefaultSignal,
|
|
connectionId,
|
|
signals,
|
|
groups,
|
|
TraceManager,
|
|
AckHandler,
|
|
Counters,
|
|
ProtectedData);
|
|
}
|
|
|
|
/// <summary>
|
|
/// Returns the default signals for the <see cref="PersistentConnection"/>.
|
|
/// </summary>
|
|
/// <param name="connectionId">The id of the incoming connection.</param>
|
|
/// <returns>The default signals for this <see cref="PersistentConnection"/>.</returns>
|
|
private IList<string> GetDefaultSignals(string connectionId)
|
|
{
|
|
// The list of default signals this connection cares about:
|
|
// 1. The default signal (the type name)
|
|
// 2. The connection id (so we can message this particular connection)
|
|
// 3. Ack signal
|
|
|
|
return new string[] {
|
|
DefaultSignal,
|
|
PrefixHelper.GetConnectionId(connectionId),
|
|
PrefixHelper.GetAck(connectionId)
|
|
};
|
|
}
|
|
|
|
/// <summary>
|
|
/// Returns the signals used in the <see cref="PersistentConnection"/>.
|
|
/// </summary>
|
|
/// <param name="connectionId">The id of the incoming connection.</param>
|
|
/// <returns>The signals used for this <see cref="PersistentConnection"/>.</returns>
|
|
protected virtual IList<string> GetSignals(string connectionId)
|
|
{
|
|
return GetDefaultSignals(connectionId);
|
|
}
|
|
|
|
/// <summary>
|
|
/// Called before every request and gives the user a authorize the user.
|
|
/// </summary>
|
|
/// <param name="request">The <see cref="IRequest"/> for the current connection.</param>
|
|
/// <returns>A boolean value that represents if the request is authorized.</returns>
|
|
protected virtual bool AuthorizeRequest(IRequest request)
|
|
{
|
|
return true;
|
|
}
|
|
|
|
/// <summary>
|
|
/// Called when a connection reconnects after a timeout to determine which groups should be rejoined.
|
|
/// </summary>
|
|
/// <param name="request">The <see cref="IRequest"/> for the current connection.</param>
|
|
/// <param name="groups">The groups the calling connection claims to be part of.</param>
|
|
/// <param name="connectionId">The id of the reconnecting client.</param>
|
|
/// <returns>A collection of group names that should be joined on reconnect</returns>
|
|
protected virtual IList<string> OnRejoiningGroups(IRequest request, IList<string> groups, string connectionId)
|
|
{
|
|
return groups;
|
|
}
|
|
|
|
/// <summary>
|
|
/// Called when a new connection is made.
|
|
/// </summary>
|
|
/// <param name="request">The <see cref="IRequest"/> for the current connection.</param>
|
|
/// <param name="connectionId">The id of the connecting client.</param>
|
|
/// <returns>A <see cref="Task"/> that completes when the connect operation is complete.</returns>
|
|
protected virtual Task OnConnected(IRequest request, string connectionId)
|
|
{
|
|
return TaskAsyncHelper.Empty;
|
|
}
|
|
|
|
/// <summary>
|
|
/// Called when a connection reconnects after a timeout.
|
|
/// </summary>
|
|
/// <param name="request">The <see cref="IRequest"/> for the current connection.</param>
|
|
/// <param name="connectionId">The id of the re-connecting client.</param>
|
|
/// <returns>A <see cref="Task"/> that completes when the re-connect operation is complete.</returns>
|
|
protected virtual Task OnReconnected(IRequest request, string connectionId)
|
|
{
|
|
return TaskAsyncHelper.Empty;
|
|
}
|
|
|
|
/// <summary>
|
|
/// Called when data is received from a connection.
|
|
/// </summary>
|
|
/// <param name="request">The <see cref="IRequest"/> for the current connection.</param>
|
|
/// <param name="connectionId">The id of the connection sending the data.</param>
|
|
/// <param name="data">The payload sent to the connection.</param>
|
|
/// <returns>A <see cref="Task"/> that completes when the receive operation is complete.</returns>
|
|
protected virtual Task OnReceived(IRequest request, string connectionId, string data)
|
|
{
|
|
return TaskAsyncHelper.Empty;
|
|
}
|
|
|
|
/// <summary>
|
|
/// Called when a connection disconnects.
|
|
/// </summary>
|
|
/// <param name="request">The <see cref="IRequest"/> for the current connection.</param>
|
|
/// <param name="connectionId">The id of the disconnected connection.</param>
|
|
/// <returns>A <see cref="Task"/> that completes when the disconnect operation is complete.</returns>
|
|
protected virtual Task OnDisconnected(IRequest request, string connectionId)
|
|
{
|
|
return TaskAsyncHelper.Empty;
|
|
}
|
|
|
|
private Task ProcessPingRequest(HostContext context)
|
|
{
|
|
var payload = new
|
|
{
|
|
Response = "pong"
|
|
};
|
|
|
|
if (!String.IsNullOrEmpty(context.Request.QueryString["callback"]))
|
|
{
|
|
return ProcessJsonpRequest(context, payload);
|
|
}
|
|
|
|
context.Response.ContentType = JsonUtility.JsonMimeType;
|
|
return context.Response.End(JsonSerializer.Stringify(payload));
|
|
}
|
|
|
|
private Task ProcessNegotiationRequest(HostContext context)
|
|
{
|
|
// Total amount of time without a keep alive before the client should attempt to reconnect in seconds.
|
|
var keepAliveTimeout = _configurationManager.KeepAliveTimeout();
|
|
string connectionId = Guid.NewGuid().ToString("d");
|
|
string connectionToken = connectionId + ':' + GetUserIdentity(context);
|
|
|
|
var payload = new
|
|
{
|
|
Url = context.Request.Url.LocalPath.Replace("/negotiate", ""),
|
|
ConnectionToken = ProtectedData.Protect(connectionToken, Purposes.ConnectionToken),
|
|
ConnectionId = connectionId,
|
|
KeepAliveTimeout = keepAliveTimeout != null ? keepAliveTimeout.Value.TotalSeconds : (double?)null,
|
|
DisconnectTimeout = _configurationManager.DisconnectTimeout.TotalSeconds,
|
|
TryWebSockets = _transportManager.SupportsTransport(WebSocketsTransportName) && context.SupportsWebSockets(),
|
|
WebSocketServerUrl = context.WebSocketServerUrl(),
|
|
ProtocolVersion = "1.2"
|
|
};
|
|
|
|
if (!String.IsNullOrEmpty(context.Request.QueryString["callback"]))
|
|
{
|
|
return ProcessJsonpRequest(context, payload);
|
|
}
|
|
|
|
context.Response.ContentType = JsonUtility.JsonMimeType;
|
|
return context.Response.End(JsonSerializer.Stringify(payload));
|
|
}
|
|
|
|
private static string GetUserIdentity(HostContext context)
|
|
{
|
|
if (context.Request.User != null && context.Request.User.Identity.IsAuthenticated)
|
|
{
|
|
return context.Request.User.Identity.Name ?? String.Empty;
|
|
}
|
|
return String.Empty;
|
|
}
|
|
|
|
private Task ProcessJsonpRequest(HostContext context, object payload)
|
|
{
|
|
context.Response.ContentType = JsonUtility.JavaScriptMimeType;
|
|
var data = JsonUtility.CreateJsonpCallback(context.Request.QueryString["callback"], JsonSerializer.Stringify(payload));
|
|
|
|
return context.Response.End(data);
|
|
}
|
|
|
|
private static bool IsNegotiationRequest(IRequest request)
|
|
{
|
|
return request.Url.LocalPath.EndsWith("/negotiate", StringComparison.OrdinalIgnoreCase);
|
|
}
|
|
|
|
private static bool IsPingRequest(IRequest request)
|
|
{
|
|
return request.Url.LocalPath.EndsWith("/ping", StringComparison.OrdinalIgnoreCase);
|
|
}
|
|
|
|
private ITransport GetTransport(HostContext context)
|
|
{
|
|
return _transportManager.GetTransport(context);
|
|
}
|
|
}
|
|
}
|