Is there a way for TFF clients to have internal states? - tensorflow-federated

The code in the TFF tutorials and in the research projects I've seen generally only keeps track of server state. I'd like clients to have internal state as well (for instance, additional client-internal neural networks that are completely decentralized and aren't updated in a federated manner) that would influence the federated client computations.
However, the client computations I have seen are functions only of the server state and the data. Is it possible to accomplish the above?

Yup, this is easy to express in TFF, and will execute just fine in the default execution stacks.
As you've noticed, the TFF repository generally has examples of cross-device Federated Learning (Kairouz et al. 2019). Generally the state has tff.SERVER placement, and the function signature for one "round" of federated learning has the following structure (for details about TFF's type shorthand, see the Federated data section of the tutorials):
(<State@SERVER, {Dataset}@CLIENTS> -> State@SERVER)
We can represent stateful clients by simply extending the signature:
(<State@SERVER, {State}@CLIENTS, {Dataset}@CLIENTS> -> <State@SERVER, {State}@CLIENTS>)
Implementing a version of Federated Averaging (McMahan et al. 2016) that includes a client state object might look something like:
@tff.tf_computation(
    model_type,
    client_state_type,  # additional state parameter
    client_data_type)
def client_training_fn(model, state, dataset):
  model_update, new_state = ...  # placeholder: do some local training here
  return model_update, new_state  # return a tuple including the updated state

@tff.federated_computation(
    tff.FederatedType(server_state_type, tff.SERVER),
    tff.FederatedType(client_state_type, tff.CLIENTS),  # new parameter for state
    tff.FederatedType(client_data_type, tff.CLIENTS))
def run_fed_avg(server_state, client_states, client_datasets):
  client_initial_models = tff.federated_broadcast(server_state.model)
  client_updates, new_client_states = tff.federated_map(
      client_training_fn,
      # Pass the client states as an argument.
      (client_initial_models, client_states, client_datasets))
  average_update = tff.federated_mean(client_updates)
  new_server_state = tff.federated_map(server_update_fn, (server_state, average_update))
  # Make sure to return the client states so they can be used in later rounds.
  return new_server_state, new_client_states
Invoking run_fed_avg would require passing a Python list of tensors/structures, one for each client participating in a round. The result of the invocation will be the server state and a list of client states.
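The extended signature can be made concrete with a plain-Java simulation of one round. This is not TFF code. The server model is a single double. Each client keeps a private ClientState that it updates locally and returns, and that state is never averaged. The class names and the toy "training" rule are hypothetical.

```java
import java.util.ArrayList;
import java.util.List;

public class StatefulFedAvg {
    // Private, per-client state: updated locally and returned, never averaged.
    record ClientState(int roundsSeen, double localBias) {}

    record ClientOutput(double modelUpdate, ClientState newState) {}

    record RoundResult(double newServerModel, List<ClientState> newClientStates) {}

    // Analogue of client_training_fn: depends on the broadcast model,
    // this client's own state, and its local data.
    static ClientOutput clientTraining(double model, ClientState state, double[] data) {
        double mean = 0.0;
        for (double x : data) {
            mean += x;
        }
        mean /= data.length;
        // Toy "training": the update points toward the local mean shifted by the private bias.
        double update = (mean + state.localBias()) - model;
        // The private state evolves locally only.
        ClientState next = new ClientState(state.roundsSeen() + 1, 0.5 * state.localBias());
        return new ClientOutput(update, next);
    }

    // <model@SERVER, {state}@CLIENTS, {data}@CLIENTS> -> <model@SERVER, {state}@CLIENTS>
    static RoundResult runRound(double serverModel, List<ClientState> clientStates,
                                List<double[]> clientData) {
        double sum = 0.0;
        List<ClientState> newStates = new ArrayList<>();
        for (int i = 0; i < clientStates.size(); i++) {
            ClientOutput out = clientTraining(serverModel, clientStates.get(i), clientData.get(i));
            sum += out.modelUpdate();
            newStates.add(out.newState());
        }
        // Analogue of an unweighted federated_mean followed by a server update that adds it.
        double averageUpdate = sum / clientStates.size();
        return new RoundResult(serverModel + averageUpdate, newStates);
    }

    public static void main(String[] args) {
        List<ClientState> states = List.of(new ClientState(0, 1.0), new ClientState(0, -1.0));
        List<double[]> data = List.of(new double[] {2.0, 4.0}, new double[] {6.0});
        RoundResult r1 = runRound(0.0, states, data);
        // Feed the returned client states into the next round.
        RoundResult r2 = runRound(r1.newServerModel(), r1.newClientStates(), data);
        System.out.println(r2.newServerModel() + " " + r2.newClientStates());
    }
}
```

The point of the design is that the round function returns the new client states next to the server state, and the caller threads them into the next round.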

Related

Saxon - s9api - setParameter as node and access in transformation

We are trying to add parameters to a transformation at runtime. The only way we have found is to set every parameter individually rather than as a node. We don't yet know how to create a node for setParameter.
Current setParameter:
setParameter(new QName("TEST"), new XdmAtomicValue(24))
Expected setParameter:
<TempNode> <local>Value1</local> </TempNode>
We searched, and tried to create an XdmNode and an XdmItem.
If you want to create an XdmNode by parsing XML, the best way to do it is:
DocumentBuilder db = processor.newDocumentBuilder();
XdmNode node = db.build(new StreamSource(
new StringReader("<doc><elem/></doc>")));
You could also pass a string containing lexical XML as the parameter value, and then convert it to a tree by calling the XPath parse-xml() function.
If you want to construct the XdmNode programmatically, there are a number of options:
DocumentBuilder.newBuildingStreamWriter() gives you an instance of BuildingStreamWriter, which extends XMLStreamWriter. You can create the document by writing events to it using methods such as writeStartElement, writeCharacters and writeEndElement; at the end, call getDocumentNode() on the BuildingStreamWriter, which gives you an XdmNode. This has the advantage that XMLStreamWriter is a standard API, though not a very nice one: the documentation isn't very good, and as a result implementations vary in their behaviour.
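The event sequence described above can be tried with the JDK's own StAX implementation. This sketch writes to a String via XMLOutputFactory rather than to Saxon's BuildingStreamWriter, so it yields text, not an XdmNode. With Saxon you would obtain the writer from DocumentBuilder.newBuildingStreamWriter() and finish with getDocumentNode() instead. The class name and element names are illustrative.

```java
import java.io.StringWriter;
import java.util.Map;
import java.util.TreeMap;
import javax.xml.stream.XMLOutputFactory;
import javax.xml.stream.XMLStreamException;
import javax.xml.stream.XMLStreamWriter;

public class ParamTreeEvents {
    // Writes <Test><key>value</key>...</Test> as a flat sequence of StAX events.
    static String buildParams(Map<String, String> params) {
        StringWriter out = new StringWriter();
        try {
            XMLStreamWriter w = XMLOutputFactory.newInstance().createXMLStreamWriter(out);
            w.writeStartElement("Test");
            for (Map.Entry<String, String> p : params.entrySet()) {
                w.writeStartElement(p.getKey());   // <key>
                w.writeCharacters(p.getValue());   // text content, escaped by the writer
                w.writeEndElement();               // </key>
            }
            w.writeEndElement();                   // </Test>
            w.flush();
            w.close();                             // does not close the underlying StringWriter
        } catch (XMLStreamException e) {
            throw new IllegalStateException(e);
        }
        return out.toString();
    }

    public static void main(String[] args) {
        Map<String, String> params = new TreeMap<>();  // sorted, so the output order is stable
        params.put("foo", "value 1");
        params.put("bar", "value 2");
        System.out.println(buildParams(params));
    }
}
```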
Another event-based API is Saxon's Push class; this differs from most push-based event APIs in that rather than having a flat sequence of methods like:
builder.startElement("x");
builder.characters("abc");
builder.endElement();
you have a nested sequence:
Element x = Document.elem("x");
x.text("abc");
x.close();
As mentioned by Martin, there is the "sapling" API: Saplings.doc().withChild(elem(...).withChild(elem(...))), etc. This API is radically different from anything you might be familiar with (though it's influenced by the LINQ API for tree construction on .NET), but once you've got used to it, it reads very well. The Sapling API constructs a very lightweight tree in memory (hence the name), and converts it to a fully-fledged XDM tree with a final call of SaplingDocument.toXdmNode().
If you're familiar with DOM, JDOM2, or XOM, you can construct a tree using any of those libraries and then convert it for use by Saxon. That's a bit convoluted and only really intended for applications that are already using a third-party tree model heavily (or for users who love these APIs and prefer them to anything else).
In the Saxon Java s9api, you can construct temporary trees as SaplingNode/SaplingElement/SaplingDocument, see https://www.saxonica.com/html/documentation12/javadoc/net/sf/saxon/sapling/SaplingDocument.html and https://www.saxonica.com/html/documentation12/javadoc/net/sf/saxon/sapling/SaplingElement.html.
To give you a simple example constructing from a Map, as you seem to want to do:
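import java.util.HashMap;
import java.util.Map;
import net.sf.saxon.s9api.Processor;
import net.sf.saxon.s9api.XdmNode;
import net.sf.saxon.sapling.SaplingElement;

Processor processor = new Processor();
Map<String, String> xsltParameters = new HashMap<>();
xsltParameters.put("foo", "value 1");
xsltParameters.put("bar", "value 2");
SaplingElement saplingElement = new SaplingElement("Test");
for (Map.Entry<String, String> param : xsltParameters.entrySet())
{
    // withChild() returns a new element, so reassign the result.
    saplingElement = saplingElement.withChild(new SaplingElement(param.getKey()).withText(param.getValue()));
}
// Convert the lightweight sapling tree to a full XDM node.
XdmNode paramNode = saplingElement.toXdmNode(processor);
System.out.println(paramNode);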
outputs e.g. <Test><bar>value 2</bar><foo>value 1</foo></Test>.
So the key is to understand that withChild() returns a new SaplingElement.
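Why the reassignment matters can be seen with a toy immutable element class. Elem below is a hypothetical stand-in written for this illustration, not Saxon's SaplingElement. It only mimics the "return a new object" style of withChild() and withText().

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

public class ImmutableElemDemo {
    // Hypothetical toy class; every "modifier" returns a new instance.
    static final class Elem {
        private final String name;
        private final String text;
        private final List<Elem> children;

        Elem(String name) {
            this(name, null, List.of());
        }

        private Elem(String name, String text, List<Elem> children) {
            this.name = name;
            this.text = text;
            this.children = children;
        }

        // Returns a NEW element with the extra child; this one is left unchanged.
        Elem withChild(Elem child) {
            List<Elem> copy = new ArrayList<>(children);
            copy.add(child);
            return new Elem(name, text, Collections.unmodifiableList(copy));
        }

        Elem withText(String t) {
            return new Elem(name, t, children);
        }

        @Override
        public String toString() {
            StringBuilder sb = new StringBuilder("<").append(name).append(">");
            if (text != null) {
                sb.append(text);
            }
            for (Elem c : children) {
                sb.append(c);
            }
            return sb.append("</").append(name).append(">").toString();
        }
    }

    public static void main(String[] args) {
        Elem root = new Elem("Test");
        root.withChild(new Elem("lost").withText("x"));              // result discarded
        System.out.println(root);                                    // <Test></Test>
        root = root.withChild(new Elem("foo").withText("value 1"));  // reassigned, so kept
        System.out.println(root);                                    // <Test><foo>value 1</foo></Test>
    }
}
```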
The code can be compacted using streams e.g.
XdmNode paramNode2 = Saplings.elem("root").withChild(
        xsltParameters
            .entrySet()
            .stream()
            .map(p -> Saplings.elem(p.getKey()).withText(p.getValue()))
            .toArray(SaplingElement[]::new))
    .toXdmNode(processor);
System.out.println(paramNode2);

Can I visualize a Multibody pose without explicitly calculating every body's full transform?

In the examples/quadrotor/ example, a custom QuadrotorPlant is specified and its output is passed into QuadrotorGeometry where the QuadrotorPlant state is packaged into FramePoseVector for the SceneGraph to visualize.
The relevant code segment in QuadrotorGeometry that does this:
...
builder->Connect(
quadrotor_geometry->get_output_port(0),
scene_graph->get_source_pose_port(quadrotor_geometry->source_id_));
...
void QuadrotorGeometry::OutputGeometryPose(
const systems::Context<double>& context,
geometry::FramePoseVector<double>* poses) const {
DRAKE_DEMAND(frame_id_.is_valid());
const auto& state = get_input_port(0).Eval(context);
math::RigidTransformd pose(
math::RollPitchYawd(state.segment<3>(3)),
state.head<3>());
*poses = {{frame_id_, pose.GetAsIsometry3()}};
}
In my case, I have a floating-base multibody system (think a quadrotor with a pendulum attached) for which I've created a custom plant (LeafSystem). Its configuration is described by 4 (quaternion) + 3 (x,y,z) + 1 (joint angle) = 8 position coordinates (7 degrees of freedom, since a unit quaternion has only 3 independent parameters). If I were to follow the QuadrotorGeometry example, I believe I would need to specify the full RigidTransformd for the quadrotor and the full RigidTransformd of the pendulum.
Question
Is it possible to set up the visualization / specify the pose such that I only need to specify the 8 position coordinates (quadrotor pose + joint angle) and have the internal MultibodyPlant compute each individual body's (quadrotor and pendulum) full RigidTransform, which can then be passed to the SceneGraph for visualization?
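Computing every body's pose from the generalized positions is forward kinematics, which MultibodyPlant performs internally. The plain-Java sketch below is not Drake's API. It computes only the pendulum tip position for an assumed hinge geometry, to show that the 8 positions fully determine the pendulum's pose. The hinge location, hinge axis, rod length and class name are all hypothetical.

```java
import java.util.Arrays;

public class PendulumOnQuadrotor {
    // Rotate v by the unit quaternion q = (w, x, y, z) using the equivalent rotation matrix.
    static double[] rotate(double[] q, double[] v) {
        double w = q[0], x = q[1], y = q[2], z = q[3];
        double[][] r = {
            {1 - 2 * (y * y + z * z), 2 * (x * y - w * z), 2 * (x * z + w * y)},
            {2 * (x * y + w * z), 1 - 2 * (x * x + z * z), 2 * (y * z - w * x)},
            {2 * (x * z - w * y), 2 * (y * z + w * x), 1 - 2 * (x * x + y * y)}
        };
        double[] out = new double[3];
        for (int i = 0; i < 3; i++) {
            out[i] = r[i][0] * v[0] + r[i][1] * v[1] + r[i][2] * v[2];
        }
        return out;
    }

    // World position of the pendulum tip from q = [qw, qx, qy, qz, px, py, pz, theta].
    // Assumed geometry: hinge at the quadrotor origin, rotating about the quadrotor's
    // body y-axis; at theta = 0 the rod of the given length hangs along body -z.
    static double[] pendulumTip(double[] q, double length) {
        double theta = q[7];
        // R_y(theta) applied to (0, 0, -length), expressed in the quadrotor frame.
        double[] tipInBody = {-length * Math.sin(theta), 0.0, -length * Math.cos(theta)};
        double[] tipOffset = rotate(new double[] {q[0], q[1], q[2], q[3]}, tipInBody);
        return new double[] {q[4] + tipOffset[0], q[5] + tipOffset[1], q[6] + tipOffset[2]};
    }

    public static void main(String[] args) {
        double h = Math.sqrt(0.5);  // 90 degree yaw: w = z = sqrt(1/2)
        double[] q = {h, 0, 0, h, 1.0, 2.0, 3.0, Math.PI / 2};
        System.out.println(Arrays.toString(pendulumTip(q, 1.0)));
    }
}
```

The pendulum's orientation would likewise be the quadrotor rotation composed with R_y(theta). This is the bookkeeping you avoid by letting the plant evaluate its geometry pose output.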
I believe this was possible with the "attic-ed" (which I take to mean "to be deprecated") RigidBodyTree, which was accomplished in examples/compass_gait
lcm::DrakeLcm lcm;
auto publisher = builder.AddSystem<systems::DrakeVisualizer>(*tree, &lcm);
publisher->set_name("publisher");
builder.Connect(compass_gait->get_floating_base_state_output_port(),
publisher->get_input_port(0));
Where get_floating_base_state_output_port() was outputting the CompassGait state with only 7 states (3 rpy + 3 xyz + 1 hip angle).
What is the MultibodyPlant, SceneGraph equivalent of this?
Update (using MultibodyPositionToGeometryPose from Russ's deleted answer)
I created the following function, which attempts to create a MultibodyPlant from the given model_file and connect the given plant's pose_output_port through MultibodyPositionToGeometryPose.
The pose_output_port I'm using is the 4 (quaternion) + 3 (xyz) + 1 (joint angle) position vector.
void add_plant_visuals(
systems::DiagramBuilder<double>* builder,
geometry::SceneGraph<double>* scene_graph,
const std::string model_file,
const systems::OutputPort<double>& pose_output_port)
{
multibody::MultibodyPlant<double> mbp;
multibody::Parser parser(&mbp, scene_graph);
auto model_id = parser.AddModelFromFile(model_file);
mbp.Finalize();
auto source_id = *mbp.get_source_id();
auto multibody_position_to_geometry_pose = builder->AddSystem<systems::rendering::MultibodyPositionToGeometryPose<double>>(mbp);
builder->Connect(pose_output_port,
multibody_position_to_geometry_pose->get_input_port());
builder->Connect(
multibody_position_to_geometry_pose->get_output_port(),
scene_graph->get_source_pose_port(source_id));
geometry::ConnectDrakeVisualizer(builder, *scene_graph);
}
The above fails with the following exception
abort: Failure at multibody/plant/multibody_plant.cc:2015 in get_geometry_poses_output_port(): condition 'geometry_source_is_registered()' failed.
So, there's a lot in here. I have a suspicion there's a simple answer, but we may have to converge on it.
First, my assumptions:
You've got an "internal" MultibodyPlant (MBP). Presumably, you also have a context for it, allowing you to perform meaningful state-dependent calculations.
Furthermore, I presume the MBP was responsible for registering the geometry (probably happened when you parsed it).
Your LeafSystem will directly connect to the SceneGraph to provide poses.
Given your state, you routinely set the state in the MBP's context to do that evaluation.
Option 1 (Edited):
In your custom LeafSystem, declare a FramePoseVector output port with a calc callback. Inside that callback, first copy your state into the Context your LeafSystem owns for the internal MBP. Then Eval() the internal MBP's geometry pose output port on that Context, and write the result into the FramePoseVector your callback was given.
Essentially (in a very coarse way):
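MySystem::MySystem() {
  // Output port to connect to scene_graph->get_source_pose_port(source_id).
  this->DeclareAbstractOutputPort("geometry_pose",
                                  &MySystem::OutputGeometryPose);
}
void MySystem::OutputGeometryPose(
    const systems::Context<double>& context,
    geometry::FramePoseVector<double>* poses) const {
  // my_state_vector: this system's 8 positions, read from `context`.
  // mbp_context_: a std::unique_ptr<systems::Context<double>> created once
  // (e.g. with mbp_.CreateDefaultContext()) and owned by this system.
  mbp_.SetPositions(mbp_context_.get(), my_state_vector);
  // The internal plant computes the pose of every registered frame.
  *poses = mbp_.get_geometry_poses_output_port()
               .Eval<geometry::FramePoseVector<double>>(*mbp_context_);
}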
Option 2:
Rather than implementing a LeafSystem that has an internal plant, you could have a Diagram that contains an MBP and exports the MBP's FramePoseVector output directly through the diagram to connect.
This answer addresses, specifically, your edit where you are attempting to use the MultibodyPositionToGeometryPose approach. It doesn't address the larger design issues.
Your problem is that the MultibodyPositionToGeometryPose system takes a reference to an MBP and keeps that reference. That means the MBP must be alive and well for at least as long as the MPTGP is. However, in your code snippet, your MBP is a local variable of add_plant_visuals(), so it is destroyed as soon as the function returns.
You'll need to create something that is persisted and owned by someone else.
(This is tightly related to my option 2 - now edited for improved clarity.)

Implementing Jena Dataset provider over MongoDB

I have started to implement a set of classes that provide a direct interface to MongoDB for persistence, similar in spirit to the now-unmaintained SDB persistor implementation for RDBMS.
I am using the time-honored technique of creating the necessary concrete classes from the interfaces and putting a println in each method, thereby allowing me to trace the execution. I have gotten all the way to where the engine calls out to my cursor setup:
public ExtendedIterator<Triple> find(Node s, Node p, Node o) {
System.out.println("+++ MongoGraph:extenditer:find(" + s + p + o + ")");
// TBD Need to turn s,p,o into a match expression! Easy!
MongoCursor cur = this.coll.find().iterator();
ExtendedIterator<Triple> curs = new JenaMongoCursorIterator(cur);
return curs;
}
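One way to fill in the TBD in find() is to put only the concrete positions of (s, p, o) into the query and leave the wildcard positions out. This is a minimal sketch under the following assumptions:
- Each triple is stored as a document with hypothetical fields s, p and o.
- Terms are plain strings.
- null stands in for a wildcard. In Jena the wildcard arrives as Node.ANY.

The resulting map would still have to be converted into a MongoDB driver filter.

```java
import java.util.LinkedHashMap;
import java.util.Map;

public class TriplePatternFilter {
    // Build an equality filter from the bound positions only; null means "match anything".
    static Map<String, String> buildFilter(String s, String p, String o) {
        Map<String, String> filter = new LinkedHashMap<>();
        if (s != null) {
            filter.put("s", s);
        }
        if (p != null) {
            filter.put("p", p);
        }
        if (o != null) {
            filter.put("o", o);
        }
        return filter;
    }

    // In-memory check with the same semantics as an equality query on the bound fields.
    static boolean matches(Map<String, String> doc, Map<String, String> filter) {
        for (Map.Entry<String, String> e : filter.entrySet()) {
            if (!e.getValue().equals(doc.get(e.getKey()))) {
                return false;
            }
        }
        return true;
    }

    public static void main(String[] args) {
        Map<String, String> filter = buildFilter("ex:alice", null, null);  // (alice, ANY, ANY)
        Map<String, String> doc = Map.of("s", "ex:alice", "p", "ex:knows", "o", "ex:bob");
        System.out.println(filter + " matches " + doc + ": " + matches(doc, filter));
    }
}
```

An all-wildcard find() produces an empty filter, which corresponds to the unfiltered coll.find() already in the method above.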
Sadly, when I later call this:
while(rs.hasNext()) {
QuerySolution soln = rs.nextSolution() ;
System.out.println(soln);
}
It turns out rs.hasNext() is always false even though material is present in the MongoCursor (I can debug-print it in the find() method). Also, the trace print in the next() method of my concrete iterator JenaMongoCursorIterator (which extends NiceIterator, which I believe is OK) is never hit. In short, the basic setup seems good, but the engine never cranks the iterator returned by find().
Trying to use SDB as a guide is completely overwhelming for someone not intimately familiar with the software architecture. It's fully factored and filled with interfaces and factories, and although that is excellent, it is difficult to navigate.
Has anyone tried to create their own persistor implementation and if so, what are the basic steps to getting a "hello world" running? Hello World in this case is ANY implementation, non-optimized, that can call next() on something to produce a Triple.
TLDR: It is now working.
I was coding too eagerly: JenaMongoCursorIterator contained a method hasNexl, which of course did not override hasNext (with a t) in NiceIterator, whose default implementation returns false.
This is the sort of problem that visual debugging and tracing in Eclipse make a lot easier to resolve than plain jdb. jdb is fine if you know the software architecture pretty well. If you don't, having multiple source files open and being able to mouse over variables provides a tremendous boost in the context you can build up to home in on the problem.
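The failure mode is easy to reproduce without Jena. EmptyByDefaultIterator below is a hypothetical stand-in for a base class whose hasNext() returns false by default, as described for NiceIterator. A misspelled method compiles but is never called. Annotating the method with @Override would have turned the typo into a compile error.

```java
import java.util.Iterator;
import java.util.List;
import java.util.NoSuchElementException;

public class OverrideDemo {
    // Hypothetical base class: reports "empty" unless a subclass overrides hasNext().
    static class EmptyByDefaultIterator<T> implements Iterator<T> {
        @Override
        public boolean hasNext() {
            return false;
        }

        @Override
        public T next() {
            throw new NoSuchElementException();
        }
    }

    // Misspelled method: compiles, but consumers still call the inherited hasNext().
    static class TypoIterator<T> extends EmptyByDefaultIterator<T> {
        private final Iterator<T> source;

        TypoIterator(Iterator<T> source) {
            this.source = source;
        }

        public boolean hasNexl() {  // never called by consumers
            return source.hasNext();
        }

        @Override
        public T next() {
            return source.next();
        }
    }

    // Correct wrapper: with @Override, a misspelled name would not compile.
    static class WrappingIterator<T> extends EmptyByDefaultIterator<T> {
        private final Iterator<T> source;

        WrappingIterator(Iterator<T> source) {
            this.source = source;
        }

        @Override
        public boolean hasNext() {
            return source.hasNext();
        }

        @Override
        public T next() {
            return source.next();
        }
    }

    static <T> int count(Iterator<T> it) {
        int n = 0;
        while (it.hasNext()) {
            it.next();
            n++;
        }
        return n;
    }

    public static void main(String[] args) {
        List<String> triples = List.of("s1 p1 o1", "s2 p2 o2");
        System.out.println(count(new TypoIterator<>(triples.iterator())));      // 0
        System.out.println(count(new WrappingIterator<>(triples.iterator())));  // 2
    }
}
```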

State garbage collection in Beam with GlobalWindow

Apache Beam has recently introduced state cells, through StateSpec and the @StateId annotation, with partial support in Apache Flink and Google Cloud Dataflow.
I cannot find any documentation on what happens when this is used with a GlobalWindow. In particular, is there a "state garbage collection" mechanism to get rid of state for keys that have not been seen for a while (according to some configuration), while still maintaining a single all-time state for keys that are seen frequently enough?
Or, is the amount of state used in this case going to diverge, with no way to ever reclaim state corresponding to keys that have not been seen in a while?
I am also interested in whether a potential solution would be supported in either Apache Flink or Google Cloud Dataflow.
Flink and direct runners seem to have some code for "state GC" but I am not really sure what it does and whether it is relevant when using a global window.
State can be automatically garbage collected by a Beam runner at some point after a window expires - when the input watermark exceeds the end of the window by the allowed lateness, so all further input is droppable. The exact details depend on the runner.
As you correctly determined, the Global window may never expire. Then this automatic collection of state will not be invoked. For bounded data, including drain scenarios, it actually will expire, but for a perpetual unbounded data source it will not.
If you are doing stateful processing on such data in the Global window, you can use user-defined timers (used through @TimerId, @OnTimer, and TimerSpec - I haven't blogged about these yet) to clear state after some timeout of your choosing. If the state represents an aggregation of some sort, then you'll want a timer anyhow to make sure your data is not stranded in state.
Here is a quick example of their use:
new DoFn<Foo, Baz>() {
  private static final String MY_TIMER = "my-timer";
  private static final String MY_STATE = "my-state";

  @StateId(MY_STATE)
  private final StateSpec<ValueState<Bizzle>> myState =
      StateSpecs.value(Bizzle.coder());

  @TimerId(MY_TIMER)
  private final TimerSpec myTimer =
      TimerSpecs.timer(TimeDomain.EVENT_TIME);

  @ProcessElement
  public void process(
      ProcessContext c,
      @StateId(MY_STATE) ValueState<Bizzle> bizzleState,
      @TimerId(MY_TIMER) Timer myTimer) {
    bizzleState.write(...);
    myTimer.setForNowPlus(...);
  }

  @OnTimer(MY_TIMER)
  public void onMyTimer(
      OnTimerContext context,
      @StateId(MY_STATE) ValueState<Bizzle> bizzleState) {
    context.output(... bizzleState.read() ...);
    // Clearing the state cell lets the runner reclaim its storage.
    bizzleState.clear();
  }
}
There is no automatic garbage collection of state if you use GlobalWindows. State is garbage collected only if you use some non-global window, once the watermark passes the end of the window plus the allowed lateness.
If you must work with GlobalWindows, you can manually keep the last update timestamp as state. Then you periodically set a timer in which you check this timestamp against the current time and delete the state if necessary. You would set this timer when encountering a key for the first time (which you can detect from the absence of your timestamp state) and then re-set it in the @OnTimer method.
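The timestamp-plus-timer pattern above can be simulated in plain Java to check its logic. This is not Beam code. The simulation makes these substitutions:
- KeyedStateGc is a hypothetical class.
- A PriorityQueue stands in for the runner's timers.
- Times are plain longs.
- process() and onTimer() play the roles of the @ProcessElement and @OnTimer methods.

```java
import java.util.Comparator;
import java.util.HashMap;
import java.util.Map;
import java.util.PriorityQueue;

public class KeyedStateGc {
    private record Timer(long fireAt, String key) {}

    private final Map<String, Long> sums = new HashMap<>();      // the per-key "real" state
    private final Map<String, Long> lastSeen = new HashMap<>();  // last update timestamp state
    private final PriorityQueue<Timer> timers =
        new PriorityQueue<>(Comparator.comparingLong(Timer::fireAt));
    private final long ttl;

    KeyedStateGc(long ttl) {
        this.ttl = ttl;
    }

    // Analogue of @ProcessElement.
    void process(String key, long value, long now) {
        if (!lastSeen.containsKey(key)) {
            // First time this key is seen (no timestamp state yet): set the timer.
            timers.add(new Timer(now + ttl, key));
        }
        sums.merge(key, value, Long::sum);
        lastSeen.put(key, now);
    }

    // Fires all timers due at or before `now`; stands in for the runner.
    void advanceTo(long now) {
        while (!timers.isEmpty() && timers.peek().fireAt() <= now) {
            Timer t = timers.poll();
            onTimer(t.key(), t.fireAt());
        }
    }

    // Analogue of @OnTimer: clear stale state, otherwise re-set the timer.
    private void onTimer(String key, long firedAt) {
        long idle = firedAt - lastSeen.get(key);
        if (idle >= ttl) {
            sums.remove(key);
            lastSeen.remove(key);
        } else {
            timers.add(new Timer(lastSeen.get(key) + ttl, key));
        }
    }

    Map<String, Long> liveState() {
        return Map.copyOf(sums);
    }

    public static void main(String[] args) {
        KeyedStateGc gc = new KeyedStateGc(10);
        gc.process("a", 1, 0);
        gc.process("b", 5, 0);
        gc.process("a", 2, 8);  // "a" is refreshed, "b" is not
        gc.advanceTo(10);
        System.out.println(gc.liveState());
    }
}
```

Each key has at most one pending timer, because a timer is only set on first sight of a key or re-set from onTimer. Frequently seen keys therefore keep their all-time state, while idle ones are dropped after roughly ttl.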

Is it possible to keep track of the state when processing logs by DataFlow?

Is it possible to have the Dataflow process maintain state? There are log processing tools that allow for that by providing fast-access (proprietary / in-memory) files that the real-time process can use to keep track of state while processing the logs.
A use case example would be tracking the registration steps taken by users. The registration steps would come in different logs, and the data from those logs would be assembled by the real-time process into one final database record (for each registered user) that is written to a database.
Can my Dataflow code keep track of the many registration steps (streaming input) by users and, once a user's registration steps are complete, have the Dataflow process write the record to the database (one record per user)?
I don't know much about Dataflow's architecture. It must be using some (proprietary / in-memory NoSQL) data storage for keeping track of things it needs to track (e.g. when it produces the top 100 customers). Is that fast-access data storage also available to Dataflow processes?
Thanks
As danielm said, state is not yet exposed. The good news is you may not need it for your use case.
If you have a PCollection<KV<UserId, LogEvent>> you can use a CombineFn and Combine.perKey to take all of the LogEvents for a specific UserId and combine them into a single output. The CombineFn tells Dataflow how to create an accumulator, update it by incorporating input elements, and then extract a final output. Transforms like Top actually use a CombineFn (with a Heap as the accumulator) rather than an actual state API.
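The four steps a CombineFn defines can be sketched in plain Java. SimpleCombineFn is a hypothetical interface that only mirrors the method names of the SDK's CombineFn (createAccumulator, addInput, mergeAccumulators, extractOutput). combinePerKey splits each key's inputs into two partial accumulators, as if they had been processed on different workers, to show why merging must be supported. The registration step names are made up.

```java
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.TreeSet;

public class RegistrationCombine {
    // Hypothetical interface mirroring the four CombineFn steps (not the SDK class).
    interface SimpleCombineFn<InputT, AccumT, OutputT> {
        AccumT createAccumulator();
        AccumT addInput(AccumT acc, InputT input);
        AccumT mergeAccumulators(List<AccumT> accs);
        OutputT extractOutput(AccumT acc);
    }

    // Accumulates the (sorted) set of registration steps seen for one user.
    static final class StepsFn implements SimpleCombineFn<String, Set<String>, String> {
        public Set<String> createAccumulator() { return new TreeSet<>(); }
        public Set<String> addInput(Set<String> acc, String step) { acc.add(step); return acc; }
        public Set<String> mergeAccumulators(List<Set<String>> accs) {
            Set<String> merged = new TreeSet<>();
            for (Set<String> a : accs) {
                merged.addAll(a);
            }
            return merged;
        }
        public String extractOutput(Set<String> acc) { return String.join(",", acc); }
    }

    static <K, I, A, O> Map<K, O> combinePerKey(List<Map.Entry<K, I>> events,
                                                SimpleCombineFn<I, A, O> fn) {
        Map<K, List<I>> byKey = new LinkedHashMap<>();
        for (Map.Entry<K, I> e : events) {
            byKey.computeIfAbsent(e.getKey(), k -> new ArrayList<>()).add(e.getValue());
        }
        Map<K, O> out = new LinkedHashMap<>();
        for (Map.Entry<K, List<I>> e : byKey.entrySet()) {
            List<I> inputs = e.getValue();
            int mid = inputs.size() / 2;
            // Two partial accumulators, then merge, then extract the final output.
            A left = fn.createAccumulator();
            for (I in : inputs.subList(0, mid)) {
                left = fn.addInput(left, in);
            }
            A right = fn.createAccumulator();
            for (I in : inputs.subList(mid, inputs.size())) {
                right = fn.addInput(right, in);
            }
            out.put(e.getKey(), fn.extractOutput(fn.mergeAccumulators(List.of(left, right))));
        }
        return out;
    }

    public static void main(String[] args) {
        List<Map.Entry<String, String>> events = List.of(
            Map.entry("alice", "email"), Map.entry("bob", "email"),
            Map.entry("alice", "password"), Map.entry("alice", "profile"));
        System.out.println(combinePerKey(events, new StepsFn()));
    }
}
```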
If your events are of different types, you can still do something like this. For instance, if you have two logs, you can do:
PCollection<KV<UserId, LogEvent1>> events1 = ...;
PCollection<KV<UserId, LogEvent2>> events2 = ...;
// Create tuple tags for the value types in each collection.
final TupleTag<LogEvent1> tag1 = new TupleTag<LogEvent1>();
final TupleTag<LogEvent2> tag2 = new TupleTag<LogEvent2>();
// Merge collection values into a CoGbkResult collection.
PCollection<KV<UserId, CoGbkResult>> coGbkResultCollection =
    KeyedPCollectionTuple.of(tag1, events1)
        .and(tag2, events2)
        .apply(CoGroupByKey.<UserId>create());
// Access results and do something.
PCollection<T> finalResultCollection =
    coGbkResultCollection.apply(ParDo.of(
        new DoFn<KV<UserId, CoGbkResult>, T>() {
          @Override
          public void processElement(ProcessContext c) {
            KV<UserId, CoGbkResult> e = c.element();
            // Get all LogEvent1 values.
            Iterable<LogEvent1> event1s = e.getValue().getAll(tag1);
            // There will only be one LogEvent2.
            LogEvent2 event2 = e.getValue().getOnly(tag2);
            // ... do something to compute T ...
            c.output(...some T...);
          }
        }));
The above example was adapted from the docs on CoGroupByKey, which have more information.
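Conceptually, CoGroupByKey groups several keyed collections by key and keeps each input's values separate. This plain-Java sketch uses hypothetical names and is not the SDK. For each key it returns one list per input:
- getAll(tag) corresponds to reading one of those lists.
- getOnly(tag) corresponds to additionally requiring that list to hold exactly one element.
- A key missing from one input simply gets an empty list for that input.

```java
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

public class CoGroupSketch {
    // Per-key result holding the values from each input separately.
    record CoGrouped<A, B>(List<A> fromFirst, List<B> fromSecond) {
        static <X, Y> CoGrouped<X, Y> empty() {
            return new CoGrouped<>(new ArrayList<>(), new ArrayList<>());
        }
    }

    static <K, A, B> Map<K, CoGrouped<A, B>> coGroupByKey(
            List<Map.Entry<K, A>> first, List<Map.Entry<K, B>> second) {
        Map<K, CoGrouped<A, B>> out = new LinkedHashMap<>();
        for (Map.Entry<K, A> e : first) {
            out.computeIfAbsent(e.getKey(), k -> CoGrouped.<A, B>empty()).fromFirst().add(e.getValue());
        }
        for (Map.Entry<K, B> e : second) {
            out.computeIfAbsent(e.getKey(), k -> CoGrouped.<A, B>empty()).fromSecond().add(e.getValue());
        }
        return out;
    }

    public static void main(String[] args) {
        List<Map.Entry<String, String>> steps =
            List.of(Map.entry("alice", "email"), Map.entry("alice", "password"));
        List<Map.Entry<String, String>> confirmations =
            List.of(Map.entry("alice", "confirmed"), Map.entry("bob", "confirmed"));
        System.out.println(coGroupByKey(steps, confirmations));
    }
}
```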
Dataflow does not currently expose the underlying state mechanism that it uses. However, this is definitely on the radar for a future update.
